Compare commits

..

1 Commits

Author SHA1 Message Date
944913c0fa docs: clarify remaining v0 references 2025-10-06 10:59:13 -07:00
1502 changed files with 32963 additions and 62236 deletions

View File

@ -5,11 +5,11 @@ import os
import sys import sys
import zipfile import zipfile
# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 500 MiB # Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 450 MiB
# Note that we have 800 MiB quota, please use it wisely. # Note that we have 800 MiB quota, please use it wisely.
# See https://github.com/pypi/support/issues/6326 . # See https://github.com/pypi/support/issues/6326 .
# Please also sync the value with the one in Dockerfile. # Please also sync the value with the one in Dockerfile.
VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 500)) VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 450))
def print_top_10_largest_files(zip_file): def print_top_10_largest_files(zip_file):

View File

@ -1,12 +0,0 @@
# For vllm script, with -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
model_name: "HandH1998/QQQ-Llama-3-8b-g128"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.419
- name: "exact_match,flexible-extract"
value: 0.416
limit: 1000
num_fewshot: 5

View File

@ -1,12 +0,0 @@
# For hf script, without -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-chartqa-vllm-vlm-baseline.sh -m meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 -l 100 -t 8
model_name: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
backend: "vllm-vlm"
tasks:
- name: "chartqa"
metrics:
- name: "relaxed_accuracy,none"
# TODO(zhewenl): model card is 0.90, but the actual score is 0.80.
value: 0.80
limit: 100
num_fewshot: 0

View File

@ -1,10 +0,0 @@
# For hf script, without -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-mmlupro-vllm-baseline.sh -m meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 -l 250 -t 8 -f 5
model_name: "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
tasks:
- name: "mmlu_pro"
metrics:
- name: "exact_match,custom-extract"
value: 0.80
limit: 250 # will run on 250 * 14 subjects = 3500 samples
num_fewshot: 5

View File

@ -1,5 +1,4 @@
# For vllm script, with -t option (tensor parallel size) # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic -b auto -l 1319 -f 5 -t 1
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic -l 1319 -t 1
model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic" model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
tasks: tasks:
- name: "gsm8k" - name: "gsm8k"

View File

@ -1,12 +0,0 @@
# For vllm script, with -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-chartqa-vllm-vlm-baseline.sh -m Qwen/Qwen2.5-VL-7B-Instruct -l 2500 -t 1
model_name: "Qwen/Qwen2.5-VL-7B-Instruct"
backend: "vllm-vlm"
tasks:
- name: "chartqa"
metrics:
- name: "relaxed_accuracy,none"
value: 0.855
limit: 2500
num_fewshot: 0

View File

@ -1 +0,0 @@
Meta-Llama-4-Maverick-17B-128E-Instruct-FP8.yaml

View File

@ -1 +0,0 @@
Meta-Llama-4-Maverick-17B-128E-Instruct-FP8-MM.yaml

View File

@ -1 +0,0 @@
Qwen2.5-VL-7B-Instruct.yaml

View File

@ -1,44 +0,0 @@
#!/bin/bash
# We can use this script to compute baseline accuracy on chartqa for vllm.
#
# Make sure you have lm-eval-harness installed:
# pip install lm-eval==0.4.9
usage() {
echo``
echo "Runs lm eval harness on ChartQA using multimodal vllm."
echo "This pathway is intended to be used to create baselines for "
echo "our correctness tests in vllm's CI."
echo
echo "usage: ${0} <options>"
echo
echo " -m - huggingface stub or local directory of the model"
echo " -l - limit number of samples to run"
echo " -t - tensor parallel size to run at"
echo
}
while getopts "m:l:t:" OPT; do
case ${OPT} in
m )
MODEL="$OPTARG"
;;
l )
LIMIT="$OPTARG"
;;
t )
TP_SIZE="$OPTARG"
;;
\? )
usage
exit 1
;;
esac
done
lm_eval --model vllm-vlm \
--model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE" \
--tasks chartqa \
--batch_size auto \
--apply_chat_template \
--limit $LIMIT

View File

View File

@ -1,50 +0,0 @@
#!/bin/bash
# We can use this script to compute baseline accuracy on MMLUPRO for vllm.
# We use this for fp8, which HF does not support.
#
# Make sure you have lm-eval-harness installed:
# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
usage() {
echo``
echo "Runs lm eval harness on MMLU Pro using huggingface transformers."
echo "This pathway is intended to be used to create baselines for "
echo "our automated nm-test-accuracy workflow"
echo
echo "usage: ${0} <options>"
echo
echo " -m - huggingface stub or local directory of the model"
echo " -l - limit number of samples to run"
echo " -f - number of fewshot samples to use"
echo " -t - tensor parallel size to run at"
echo
}
while getopts "m:b:l:f:t:" OPT; do
case ${OPT} in
m )
MODEL="$OPTARG"
;;
b )
BATCH_SIZE="$OPTARG"
;;
l )
LIMIT="$OPTARG"
;;
f )
FEWSHOT="$OPTARG"
;;
t )
TP_SIZE="$OPTARG"
;;
\? )
usage
exit 1
;;
esac
done
lm_eval --model vllm \
--model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,trust_remote_code=true,max_model_len=4096" \
--tasks mmlu_pro --num_fewshot "$FEWSHOT" --limit "$LIMIT" \
--batch_size auto

View File

@ -19,27 +19,21 @@ RTOL = 0.08
def launch_lm_eval(eval_config, tp_size): def launch_lm_eval(eval_config, tp_size):
trust_remote_code = eval_config.get("trust_remote_code", False) trust_remote_code = eval_config.get("trust_remote_code", False)
max_model_len = eval_config.get("max_model_len", 4096) max_model_len = eval_config.get("max_model_len", 4096)
batch_size = eval_config.get("batch_size", "auto")
backend = eval_config.get("backend", "vllm")
model_args = ( model_args = (
f"pretrained={eval_config['model_name']}," f"pretrained={eval_config['model_name']},"
f"tensor_parallel_size={tp_size}," f"tensor_parallel_size={tp_size},"
f"enforce_eager=true," f"enforce_eager=true,"
f"add_bos_token=true," f"add_bos_token=true,"
f"trust_remote_code={trust_remote_code}," f"trust_remote_code={trust_remote_code},"
f"max_model_len={max_model_len}," f"max_model_len={max_model_len}"
) )
results = lm_eval.simple_evaluate( results = lm_eval.simple_evaluate(
model=backend, model="vllm",
model_args=model_args, model_args=model_args,
tasks=[task["name"] for task in eval_config["tasks"]], tasks=[task["name"] for task in eval_config["tasks"]],
num_fewshot=eval_config["num_fewshot"], num_fewshot=eval_config["num_fewshot"],
limit=eval_config["limit"], limit=eval_config["limit"],
# TODO(yeq): using chat template w/ fewshot_as_multiturn is supposed help batch_size="auto",
# text models. however, this is regressing measured strict-match for
# existing text models in CI, so only apply it for mm.
apply_chat_template=backend == "vllm-vlm",
batch_size=batch_size,
) )
return results return results

View File

@ -454,6 +454,11 @@ main() {
fi fi
check_hf_token check_hf_token
# Set to v1 to run v1 benchmark
if [[ "${ENGINE_VERSION:-v0}" == "v1" ]]; then
export VLLM_USE_V1=1
fi
# dependencies # dependencies
(which wget && which curl) || (apt-get update && apt-get install -y wget curl) (which wget && which curl) || (apt-get update && apt-get install -y wget curl)
(which jq) || (apt-get update && apt-get -y install jq) (which jq) || (apt-get update && apt-get -y install jq)

View File

@ -8,21 +8,7 @@ steps:
commands: commands:
# #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here: # #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here:
# https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7 # https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg VLLM_MAIN_CUDA_VERSION=12.9 --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg VLLM_MAIN_CUDA_VERSION=12.9 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
- "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- "bash .buildkite/scripts/upload-wheels.sh"
env:
DOCKER_BUILDKIT: "1"
# aarch64 build.
- label: "Build arm64 CPU wheel"
depends_on: ~
id: build-wheel-arm64-cpu
agents:
queue: arm64_cpu_queue_postmerge
commands:
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile.cpu ."
- "mkdir artifacts" - "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- "bash .buildkite/scripts/upload-wheels.sh" - "bash .buildkite/scripts/upload-wheels.sh"
@ -62,7 +48,7 @@ steps:
agents: agents:
queue: cpu_queue_postmerge queue: cpu_queue_postmerge
commands: commands:
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
- "mkdir artifacts" - "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- "bash .buildkite/scripts/upload-wheels.sh" - "bash .buildkite/scripts/upload-wheels.sh"
@ -90,7 +76,7 @@ steps:
queue: arm64_cpu_queue_postmerge queue: arm64_cpu_queue_postmerge
commands: commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg torch_cuda_arch_list='8.7 8.9 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ." - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)" - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)"
# Add job to create multi-arch manifest # Add job to create multi-arch manifest
@ -156,22 +142,6 @@ steps:
env: env:
DOCKER_BUILDKIT: "1" DOCKER_BUILDKIT: "1"
- block: "Build arm64 CPU release image"
key: block-arm64-cpu-release-image-build
depends_on: ~
- label: "Build and publish arm64 CPU release image"
depends_on: block-arm64-cpu-release-image-build
agents:
queue: arm64_cpu_queue_postmerge
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:latest"
- "docker push public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
env:
DOCKER_BUILDKIT: "1"
- label: "Build and publish nightly multi-arch image to DockerHub" - label: "Build and publish nightly multi-arch image to DockerHub"
depends_on: depends_on:
- create-multi-arch-manifest - create-multi-arch-manifest

View File

@ -25,28 +25,25 @@ function cpu_tests() {
# offline inference # offline inference
podman exec -it "$container_id" bash -c " podman exec -it "$container_id" bash -c "
set -xve set -e
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" >> $HOME/test_basic.log python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
# Run basic model test # Run basic model test
podman exec -it "$container_id" bash -c " podman exec -it "$container_id" bash -c "
set -evx set -e
pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib
pip install sentence-transformers datamodel_code_generator pip install sentence-transformers datamodel_code_generator
pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model
# Note: disable Bart until supports V1
# pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model
pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-openai-community/gpt2] pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-openai-community/gpt2]
pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-facebook/opt-125m] pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-facebook/opt-125m]
pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-google/gemma-1.1-2b-it] pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-google/gemma-1.1-2b-it]
pytest -v -s tests/models/language/pooling/test_classification.py::test_models[float-jason9693/Qwen2.5-1.5B-apeach] pytest -v -s tests/models/language/pooling/test_classification.py::test_models[float-jason9693/Qwen2.5-1.5B-apeach]
# TODO: Below test case tests/models/language/pooling/test_embedding.py::test_models[True-ssmits/Qwen2-7B-Instruct-embed-base] fails on ppc64le. Disabling it for time being. pytest -v -s tests/models/language/pooling/test_embedding.py -m cpu_model"
# pytest -v -s tests/models/language/pooling/test_embedding.py -m cpu_model" >> $HOME/test_rest.log
} }
# All of CPU tests are expected to be finished less than 40 mins. # All of CPU tests are expected to be finished less than 40 mins.
export container_id export container_id
export -f cpu_tests export -f cpu_tests
timeout 120m bash -c cpu_tests timeout 40m bash -c cpu_tests

View File

@ -70,7 +70,7 @@ function cpu_tests() {
docker exec cpu-test-"$NUMA_NODE" bash -c " docker exec cpu-test-"$NUMA_NODE" bash -c "
set -e set -e
pytest -x -s -v \ pytest -x -s -v \
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs" tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs[False-10-32-neuralmagic/Llama-3.2-1B-quantized.w8a8]"
# Note: disable it until supports V1 # Note: disable it until supports V1
# Run AWQ test # Run AWQ test

View File

@ -64,9 +64,10 @@ python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git
&& python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \ && python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \
&& python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0 && python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0
echo "--- Python dependencies installed ---" echo "--- Python dependencies installed ---"
export VLLM_USE_V1=1
export VLLM_XLA_CHECK_RECOMPILATION=1 export VLLM_XLA_CHECK_RECOMPILATION=1
export VLLM_XLA_CACHE_PATH= export VLLM_XLA_CACHE_PATH=
echo "Using VLLM V1"
echo "--- Hardware Information ---" echo "--- Hardware Information ---"
# tpu-info # tpu-info

View File

@ -64,9 +64,10 @@ python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git
&& python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \ && python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \
&& python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0 && python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0
echo "--- Python dependencies installed ---" echo "--- Python dependencies installed ---"
export VLLM_USE_V1=1
export VLLM_XLA_CHECK_RECOMPILATION=1 export VLLM_XLA_CHECK_RECOMPILATION=1
export VLLM_XLA_CACHE_PATH= export VLLM_XLA_CACHE_PATH=
echo "Using VLLM V1"
echo "--- Hardware Information ---" echo "--- Hardware Information ---"
# tpu-info # tpu-info

View File

@ -44,5 +44,6 @@ docker run \
pytest -v -s v1/structured_output pytest -v -s v1/structured_output
pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py
pytest -v -s v1/test_metrics
pytest -v -s v1/test_serial_utils.py pytest -v -s v1/test_serial_utils.py
' '

View File

@ -9,6 +9,6 @@ MAX_NUM_BATCHED_TOKENS=1024
TENSOR_PARALLEL_SIZE=1 TENSOR_PARALLEL_SIZE=1
MAX_MODEL_LEN=2048 MAX_MODEL_LEN=2048
DOWNLOAD_DIR=/mnt/disks/persist DOWNLOAD_DIR=/mnt/disks/persist
EXPECTED_THROUGHPUT=8.7 EXPECTED_THROUGHPUT=10.0
INPUT_LEN=1800 INPUT_LEN=1800
OUTPUT_LEN=128 OUTPUT_LEN=128

View File

@ -42,7 +42,7 @@ echo "lanching vllm..."
echo "logging to $VLLM_LOG" echo "logging to $VLLM_LOG"
echo echo
vllm serve $MODEL \ VLLM_USE_V1=1 vllm serve $MODEL \
--seed 42 \ --seed 42 \
--max-num-seqs $MAX_NUM_SEQS \ --max-num-seqs $MAX_NUM_SEQS \
--max-num-batched-tokens $MAX_NUM_BATCHED_TOKENS \ --max-num-batched-tokens $MAX_NUM_BATCHED_TOKENS \

File diff suppressed because it is too large Load Diff

View File

@ -296,7 +296,6 @@ steps:
- tests/v1 - tests/v1
commands: commands:
# split the test to avoid interference # split the test to avoid interference
- pytest -v -s -m 'not cpu_test' v1/core
- pytest -v -s v1/executor - pytest -v -s v1/executor
- pytest -v -s v1/kv_offload - pytest -v -s v1/kv_offload
- pytest -v -s v1/sample - pytest -v -s v1/sample
@ -318,7 +317,7 @@ steps:
no_gpu: true no_gpu: true
commands: commands:
# split the test to avoid interference # split the test to avoid interference
- pytest -v -s -m 'cpu_test' v1/core - pytest -v -s v1/core
- pytest -v -s v1/structured_output - pytest -v -s v1/structured_output
- pytest -v -s v1/test_serial_utils.py - pytest -v -s v1/test_serial_utils.py
- pytest -v -s -m 'cpu_test' v1/kv_connector/unit - pytest -v -s -m 'cpu_test' v1/kv_connector/unit
@ -384,12 +383,7 @@ steps:
--num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \ --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \
--ignore=lora/test_chatglm3_tp.py \ --ignore=lora/test_chatglm3_tp.py \
--ignore=lora/test_llama_tp.py \ --ignore=lora/test_llama_tp.py \
--ignore=lora/test_llm_with_multi_loras.py \ --ignore=lora/test_llm_with_multi_loras.py
--ignore=lora/test_olmoe_tp.py \
--ignore=lora/test_deepseekv2_tp.py \
--ignore=lora/test_gptoss.py \
--ignore=lora/test_qwen3moe_tp.py
parallelism: 4 parallelism: 4
- label: PyTorch Compilation Unit Tests # 15min - label: PyTorch Compilation Unit Tests # 15min
@ -405,10 +399,11 @@ steps:
- pytest -v -s compile/test_fusion_attn.py - pytest -v -s compile/test_fusion_attn.py
- pytest -v -s compile/test_functionalization.py - pytest -v -s compile/test_functionalization.py
- pytest -v -s compile/test_silu_mul_quant_fusion.py - pytest -v -s compile/test_silu_mul_quant_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py
- pytest -v -s compile/test_async_tp.py
- pytest -v -s compile/test_fusion_all_reduce.py - pytest -v -s compile/test_fusion_all_reduce.py
- pytest -v -s compile/test_decorator.py - pytest -v -s compile/test_decorator.py
- pytest -v -s compile/test_noop_elimination.py - pytest -v -s compile/test_noop_elimination.py
- pytest -v -s compile/test_aot_compile.py
- label: PyTorch Fullgraph Smoke Test # 15min - label: PyTorch Fullgraph Smoke Test # 15min
timeout_in_minutes: 30 timeout_in_minutes: 30
@ -421,8 +416,8 @@ steps:
- pytest -v -s compile/test_basic_correctness.py - pytest -v -s compile/test_basic_correctness.py
- pytest -v -s compile/piecewise/ - pytest -v -s compile/piecewise/
- label: PyTorch Fullgraph Test # 22min - label: PyTorch Fullgraph Test # 20min
timeout_in_minutes: 35 timeout_in_minutes: 30
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]
torch_nightly: true torch_nightly: true
source_file_dependencies: source_file_dependencies:
@ -430,7 +425,6 @@ steps:
- tests/compile - tests/compile
commands: commands:
- pytest -v -s compile/test_full_graph.py - pytest -v -s compile/test_full_graph.py
- pytest -v -s compile/test_fusions_e2e.py
- label: Kernels Core Operation Test # 48min - label: Kernels Core Operation Test # 48min
timeout_in_minutes: 75 timeout_in_minutes: 75
@ -438,9 +432,8 @@ steps:
source_file_dependencies: source_file_dependencies:
- csrc/ - csrc/
- tests/kernels/core - tests/kernels/core
- tests/kernels/test_top_k_per_row.py
commands: commands:
- pytest -v -s kernels/core kernels/test_top_k_per_row.py - pytest -v -s kernels/core
- label: Kernels Attention Test %N # 23min - label: Kernels Attention Test %N # 23min
timeout_in_minutes: 35 timeout_in_minutes: 35
@ -533,9 +526,8 @@ steps:
# since torchao nightly is only compatible with torch nightly currently # since torchao nightly is only compatible with torch nightly currently
# https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now # https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
# we can only upgrade after this is resolved # we can only upgrade after this is resolved
# TODO(jerryzh168): resolve the above comment - pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
- uv pip install --system torchao==0.13.0 - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/ --ignore quantization/test_blackwell_moe.py
- label: LM Eval Small Models # 53min - label: LM Eval Small Models # 53min
timeout_in_minutes: 75 timeout_in_minutes: 75
@ -740,16 +732,6 @@ steps:
- pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
- cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work - cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
- label: Multi-Modal Accuracy Eval (Small Models) # 50min
timeout_in_minutes: 70
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies:
- vllm/multimodal/
- vllm/inputs/
- vllm/v1/core/
commands:
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1
- label: Multi-Modal Models Test (Extended) 1 - label: Multi-Modal Models Test (Extended) 1
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental]
optional: true optional: true
@ -813,8 +795,8 @@ steps:
# Whisper needs spawn method to avoid deadlock # Whisper needs spawn method to avoid deadlock
- VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper - VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper
- label: Blackwell Test # 21 min - label: Blackwell Test # 38 min
timeout_in_minutes: 30 timeout_in_minutes: 60
working_dir: "/vllm-workspace/" working_dir: "/vllm-workspace/"
gpu: b200 gpu: b200
# optional: true # optional: true
@ -827,6 +809,8 @@ steps:
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py - vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/fusion.py
- vllm/compilation/fusion_attn.py
commands: commands:
- nvidia-smi - nvidia-smi
- python3 examples/offline_inference/basic/chat.py - python3 examples/offline_inference/basic/chat.py
@ -843,32 +827,13 @@ steps:
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
- pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py - pytest -v -s tests/kernels/moe/test_mxfp4_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py # Fusion
- label: Blackwell Fusion Tests # 30 min
timeout_in_minutes: 40
working_dir: "/vllm-workspace/"
gpu: b200
source_file_dependencies:
- csrc/quantization/fp4/
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/
# can affect pattern matching
- vllm/model_executor/layers/layernorm.py
- vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py
commands:
- nvidia-smi
- pytest -v -s tests/compile/test_fusion_attn.py
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
# this runner has 2 GPUs available even though num_gpus=2 is not set
- pytest -v -s tests/compile/test_fusion_all_reduce.py - pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusions_e2e.py - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
- pytest -v -s tests/kernels/moe/test_flashinfer.py
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
- label: Blackwell GPT-OSS Eval - label: Blackwell GPT-OSS Eval
timeout_in_minutes: 60 timeout_in_minutes: 60
@ -902,7 +867,7 @@ steps:
- pytest -s -v tests/quantization/test_blackwell_moe.py - pytest -s -v tests/quantization/test_blackwell_moe.py
- label: Blackwell LM Eval Small Models - label: Blackwell LM Eval Small Models
timeout_in_minutes: 120 timeout_in_minutes: 75
gpu: b200 gpu: b200
optional: true # run on nightlies optional: true # run on nightlies
source_file_dependencies: source_file_dependencies:
@ -982,7 +947,6 @@ steps:
- pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py - pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
- VLLM_TEST_SAME_HOST=1 VLLM_TEST_WITH_DEFAULT_DEVICE_SET=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
- pytest -v -s distributed/test_sequence_parallel.py - pytest -v -s distributed/test_sequence_parallel.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
- pytest -v -s v1/worker/test_worker_memory_snapshot.py - pytest -v -s v1/worker/test_worker_memory_snapshot.py
@ -1026,11 +990,6 @@ steps:
- pytest -v -s plugins_tests/test_io_processor_plugins.py - pytest -v -s plugins_tests/test_io_processor_plugins.py
- pip uninstall prithvi_io_processor_plugin -y - pip uninstall prithvi_io_processor_plugin -y
# end io_processor plugins test # end io_processor plugins test
# begin stat_logger plugins test
- pip install -e ./plugins/vllm_add_dummy_stat_logger
- pytest -v -s plugins_tests/test_stats_logger_plugins.py
- pip uninstall dummy_stat_logger -y
# end stat_logger plugins test
# other tests continue here: # other tests continue here:
- pytest -v -s plugins_tests/test_scheduler_plugins.py - pytest -v -s plugins_tests/test_scheduler_plugins.py
- pip install -e ./plugins/vllm_add_dummy_model - pip install -e ./plugins/vllm_add_dummy_model
@ -1070,7 +1029,6 @@ steps:
- pytest -v -s -x lora/test_chatglm3_tp.py - pytest -v -s -x lora/test_chatglm3_tp.py
- pytest -v -s -x lora/test_llama_tp.py - pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_llm_with_multi_loras.py - pytest -v -s -x lora/test_llm_with_multi_loras.py
- pytest -v -s -x lora/test_olmoe_tp.py
- label: Weight Loading Multiple GPU Test # 33min - label: Weight Loading Multiple GPU Test # 33min
@ -1096,17 +1054,6 @@ steps:
- tests/weight_loading - tests/weight_loading
commands: commands:
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt
- label: NixlConnector PD accuracy tests (Distributed) # 30min
timeout_in_minutes: 30
working_dir: "/vllm-workspace/tests"
num_gpus: 4
source_file_dependencies:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh
##### multi gpus test ##### ##### multi gpus test #####
@ -1139,16 +1086,12 @@ steps:
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
##### H200 test ##### ##### H200 test #####
- label: Distributed Tests (H200) # optional - label: Distrubted Tests (H200) # optional
gpu: h200 gpu: h200
optional: true optional: true
working_dir: "/vllm-workspace/" working_dir: "/vllm-workspace/"
num_gpus: 2 num_gpus: 2
commands: commands:
- pytest -v -s tests/compile/test_async_tp.py
- pytest -v -s tests/compile/test_sequence_parallelism.py
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
- pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048

View File

@ -1,10 +1,5 @@
[run] [run]
# Track the installed vllm package (this is what actually gets imported during tests) source = vllm
# Use wildcard pattern to match the installed location
source =
vllm
*/dist-packages/vllm
*/site-packages/vllm
omit = omit =
*/tests/* */tests/*
*/test_* */test_*
@ -17,16 +12,6 @@ omit =
*/benchmarks/* */benchmarks/*
*/docs/* */docs/*
[paths]
# Map all possible vllm locations to a canonical "vllm" path
# This ensures coverage.combine properly merges data from different test runs
source =
vllm
/vllm-workspace/src/vllm
/vllm-workspace/vllm
*/site-packages/vllm
*/dist-packages/vllm
[report] [report]
exclude_lines = exclude_lines =
pragma: no cover pragma: no cover

View File

@ -1,4 +0,0 @@
# Migrate from `yapf` & `isort` to `ruff`
d6953beb91da4e9c99be4c0a1304a2d24189535c
# Convert `Optional[x]` to `x | None` and `Union[x, y]` to `x | y`
8fcaaf6a165e661f63fc51be906bc05b0767332f

13
.github/CODEOWNERS vendored
View File

@ -5,7 +5,9 @@
/vllm/attention @LucasWilkinson /vllm/attention @LucasWilkinson
/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn /vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn
/vllm/model_executor/layers/fused_moe @mgoin /vllm/model_executor/layers/fused_moe @mgoin
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @NickLucche
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256
/vllm/model_executor/layers/mamba @tdoublep /vllm/model_executor/layers/mamba @tdoublep
/vllm/model_executor/model_loader @22quinn /vllm/model_executor/model_loader @22quinn
@ -24,6 +26,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/vllm/config/cache.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg @heheda12345 /vllm/config/cache.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg @heheda12345
# vLLM V1 # vLLM V1
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
/vllm/v1/attention @LucasWilkinson /vllm/v1/attention @LucasWilkinson
/vllm/v1/attention/backends/flashinfer.py @mgoin /vllm/v1/attention/backends/flashinfer.py @mgoin
/vllm/v1/attention/backends/triton_attn.py @tdoublep /vllm/v1/attention/backends/triton_attn.py @tdoublep
@ -57,7 +60,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/v1/offloading @ApostaC /tests/v1/offloading @ApostaC
# Transformers backend # Transformers backend
/vllm/model_executor/models/transformers @hmellor /vllm/model_executor/models/transformers.py @hmellor
/tests/models/test_transformers.py @hmellor /tests/models/test_transformers.py @hmellor
# Docs # Docs
@ -118,11 +121,3 @@ mkdocs.yaml @hmellor
# KVConnector installation files # KVConnector installation files
/requirements/kv_connectors.txt @NickLucche /requirements/kv_connectors.txt @NickLucche
# Pooling models
/examples/*/pooling/ @noooop
/tests/models/*/pooling* @noooop
/tests/entrypoints/pooling @noooop
/vllm/config/pooler.py @noooop
/vllm/pooling_params.py @noooop
/vllm/model_executor/layers/pooler.py @noooop

2
.github/mergify.yml vendored
View File

@ -11,8 +11,6 @@ pull_request_rules:
label: label:
add: add:
- documentation - documentation
comment:
message: "Documentation preview: https://vllm--{{number}}.org.readthedocs.build/en/{{number}}/"
- name: label-ci-build - name: label-ci-build
description: Automatically apply ci/build label description: Automatically apply ci/build label

View File

@ -13,7 +13,6 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Label issues based on keywords - name: Label issues based on keywords
id: label-step
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with: with:
script: | script: |
@ -43,6 +42,7 @@ jobs:
searchIn: "body" searchIn: "body"
}, },
], ],
// Substring search - matches anywhere in text (partial matches) // Substring search - matches anywhere in text (partial matches)
substrings: [ substrings: [
{ {
@ -89,12 +89,14 @@ jobs:
term: "hip_", term: "hip_",
searchIn: "both" searchIn: "both"
}, },
// ROCm tools and libraries // ROCm tools and libraries
{ {
term: "hipify", term: "hipify",
searchIn: "both" searchIn: "both"
}, },
], ],
// Regex patterns - for complex pattern matching // Regex patterns - for complex pattern matching
regexPatterns: [ regexPatterns: [
{ {
@ -105,17 +107,13 @@ jobs:
} }
], ],
}, },
// Add more label configurations here as needed
// example: {
// keywords: [...],
// substrings: [...],
// regexPatterns: [...]
// },
}; };
// Helper function to create regex based on search type // Helper function to create regex based on search type
function createSearchRegex(term, type) { function createSearchRegex(term, type) {
// Escape special regex characters in the term // Escape special regex characters in the term
const escapedTerm = term.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); const escapedTerm = term.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
switch (type) { switch (type) {
case 'keyword': case 'keyword':
// Word boundary search - matches whole words only // Word boundary search - matches whole words only
@ -127,13 +125,16 @@ jobs:
throw new Error(`Unknown search type: ${type}`); throw new Error(`Unknown search type: ${type}`);
} }
} }
// Helper function to find matching terms in text with line information // Helper function to find matching terms in text with line information
function findMatchingTermsWithLines(text, searchTerms = [], searchType = 'keyword', searchLocation = '') { function findMatchingTermsWithLines(text, searchTerms = [], searchType = 'keyword', searchLocation = '') {
const matches = []; const matches = [];
const lines = text.split('\n'); const lines = text.split('\n');
for (const termConfig of searchTerms) { for (const termConfig of searchTerms) {
let regex; let regex;
let term, searchIn, pattern, description, flags; let term, searchIn, pattern, description, flags;
// Handle different input formats (string or object) // Handle different input formats (string or object)
if (typeof termConfig === 'string') { if (typeof termConfig === 'string') {
term = termConfig; term = termConfig;
@ -145,17 +146,21 @@ jobs:
description = termConfig.description; description = termConfig.description;
flags = termConfig.flags; flags = termConfig.flags;
} }
// Skip if this term shouldn't be searched in the current location // Skip if this term shouldn't be searched in the current location
if (searchIn !== 'both' && searchIn !== searchLocation) { if (searchIn !== 'both' && searchIn !== searchLocation) {
continue; continue;
} }
// Create appropriate regex // Create appropriate regex
if (searchType === 'regex') { if (searchType === 'regex') {
regex = new RegExp(pattern, flags || "gi"); regex = new RegExp(pattern, flags || "gi");
} else { } else {
regex = createSearchRegex(term, searchType); regex = createSearchRegex(term, searchType);
} }
const termMatches = []; const termMatches = [];
// Check each line for matches // Check each line for matches
lines.forEach((line, lineIndex) => { lines.forEach((line, lineIndex) => {
const lineMatches = line.match(regex); const lineMatches = line.match(regex);
@ -170,14 +175,15 @@ jobs:
originalTerm: term || pattern, originalTerm: term || pattern,
description: description, description: description,
// Show context around the match in the line // Show context around the match in the line
context: line.length > 100 ? context: line.length > 100 ?
line.substring(Math.max(0, line.toLowerCase().indexOf(match.toLowerCase()) - 30), line.substring(Math.max(0, line.toLowerCase().indexOf(match.toLowerCase()) - 30),
line.toLowerCase().indexOf(match.toLowerCase()) + match.length + 30) + '...' line.toLowerCase().indexOf(match.toLowerCase()) + match.length + 30) + '...'
: line.trim() : line.trim()
}); });
}); });
} }
}); });
if (termMatches.length > 0) { if (termMatches.length > 0) {
matches.push({ matches.push({
term: term || (description || pattern), term: term || (description || pattern),
@ -190,48 +196,64 @@ jobs:
}); });
} }
} }
return matches; return matches;
} }
// Helper function to check if label should be added // Helper function to check if label should be added
async function processLabel(labelName, config) { async function processLabel(labelName, config) {
const body = context.payload.issue.body || ""; const body = context.payload.issue.body || "";
const title = context.payload.issue.title || ""; const title = context.payload.issue.title || "";
core.notice(`Processing label: ${labelName}`); core.notice(`Processing label: ${labelName}`);
core.notice(`Issue Title: "${title}"`); core.notice(`Issue Title: "${title}"`);
core.notice(`Issue Body length: ${body.length} characters`); core.notice(`Issue Body length: ${body.length} characters`);
let shouldAddLabel = false; let shouldAddLabel = false;
let allMatches = []; let allMatches = [];
let reason = ''; let reason = '';
const keywords = config.keywords || []; const keywords = config.keywords || [];
const substrings = config.substrings || []; const substrings = config.substrings || [];
const regexPatterns = config.regexPatterns || []; const regexPatterns = config.regexPatterns || [];
core.notice(`Searching with ${keywords.length} keywords, ${substrings.length} substrings, and ${regexPatterns.length} regex patterns`); core.notice(`Searching with ${keywords.length} keywords, ${substrings.length} substrings, and ${regexPatterns.length} regex patterns`);
// Search in title // Search in title
if (title.trim()) { if (title.trim()) {
core.notice(`Searching in title: "${title}"`); core.notice(`Searching in title: "${title}"`);
const titleKeywordMatches = findMatchingTermsWithLines(title, keywords, 'keyword', 'title'); const titleKeywordMatches = findMatchingTermsWithLines(title, keywords, 'keyword', 'title');
const titleSubstringMatches = findMatchingTermsWithLines(title, substrings, 'substring', 'title'); const titleSubstringMatches = findMatchingTermsWithLines(title, substrings, 'substring', 'title');
const titleRegexMatches = findMatchingTermsWithLines(title, regexPatterns, 'regex', 'title'); const titleRegexMatches = findMatchingTermsWithLines(title, regexPatterns, 'regex', 'title');
allMatches.push(...titleKeywordMatches, ...titleSubstringMatches, ...titleRegexMatches); allMatches.push(...titleKeywordMatches, ...titleSubstringMatches, ...titleRegexMatches);
} }
// Search in body // Search in body
if (body.trim()) { if (body.trim()) {
core.notice(`Searching in body (${body.length} characters)`); core.notice(`Searching in body (${body.length} characters)`);
const bodyKeywordMatches = findMatchingTermsWithLines(body, keywords, 'keyword', 'body'); const bodyKeywordMatches = findMatchingTermsWithLines(body, keywords, 'keyword', 'body');
const bodySubstringMatches = findMatchingTermsWithLines(body, substrings, 'substring', 'body'); const bodySubstringMatches = findMatchingTermsWithLines(body, substrings, 'substring', 'body');
const bodyRegexMatches = findMatchingTermsWithLines(body, regexPatterns, 'regex', 'body'); const bodyRegexMatches = findMatchingTermsWithLines(body, regexPatterns, 'regex', 'body');
allMatches.push(...bodyKeywordMatches, ...bodySubstringMatches, ...bodyRegexMatches); allMatches.push(...bodyKeywordMatches, ...bodySubstringMatches, ...bodyRegexMatches);
} }
if (allMatches.length > 0) { if (allMatches.length > 0) {
core.notice(`Found ${allMatches.length} matching term(s):`); core.notice(`Found ${allMatches.length} matching term(s):`);
for (const termMatch of allMatches) { for (const termMatch of allMatches) {
const locationText = termMatch.searchLocation === 'title' ? 'title' : 'body'; const locationText = termMatch.searchLocation === 'title' ? 'title' : 'body';
const searchInText = termMatch.searchIn === 'both' ? 'both' : termMatch.searchIn; const searchInText = termMatch.searchIn === 'both' ? 'both' : termMatch.searchIn;
if (termMatch.searchType === 'regex') { if (termMatch.searchType === 'regex') {
core.notice(` 📍 Regex: "${termMatch.term}" (pattern: ${termMatch.pattern}) found ${termMatch.count} time(s) in ${locationText} (configured to search in: ${searchInText}):`); core.notice(` 📍 Regex: "${termMatch.term}" (pattern: ${termMatch.pattern}) found ${termMatch.count} time(s) in ${locationText} (configured to search in: ${searchInText}):`);
} else { } else {
core.notice(` 📍 Term: "${termMatch.term}" (${termMatch.searchType} search) found ${termMatch.count} time(s) in ${locationText} (configured to search in: ${searchInText}):`); core.notice(` 📍 Term: "${termMatch.term}" (${termMatch.searchType} search) found ${termMatch.count} time(s) in ${locationText} (configured to search in: ${searchInText}):`);
} }
// Show details for each match // Show details for each match
termMatch.matches.forEach((match, index) => { termMatch.matches.forEach((match, index) => {
core.notice(` ${index + 1}. Line ${match.lineNumber} in ${match.searchLocation}: "${match.match}" [${match.searchType}]`); core.notice(` ${index + 1}. Line ${match.lineNumber} in ${match.searchLocation}: "${match.match}" [${match.searchType}]`);
@ -244,6 +266,7 @@ jobs:
} }
}); });
} }
shouldAddLabel = true; shouldAddLabel = true;
const totalMatches = allMatches.reduce((sum, t) => sum + t.count, 0); const totalMatches = allMatches.reduce((sum, t) => sum + t.count, 0);
const titleMatches = allMatches.filter(t => t.searchLocation === 'title').reduce((sum, t) => sum + t.count, 0); const titleMatches = allMatches.filter(t => t.searchLocation === 'title').reduce((sum, t) => sum + t.count, 0);
@ -251,10 +274,13 @@ jobs:
const keywordMatches = allMatches.filter(t => t.searchType === 'keyword').reduce((sum, t) => sum + t.count, 0); const keywordMatches = allMatches.filter(t => t.searchType === 'keyword').reduce((sum, t) => sum + t.count, 0);
const substringMatches = allMatches.filter(t => t.searchType === 'substring').reduce((sum, t) => sum + t.count, 0); const substringMatches = allMatches.filter(t => t.searchType === 'substring').reduce((sum, t) => sum + t.count, 0);
const regexMatches = allMatches.filter(t => t.searchType === 'regex').reduce((sum, t) => sum + t.count, 0); const regexMatches = allMatches.filter(t => t.searchType === 'regex').reduce((sum, t) => sum + t.count, 0);
reason = `Found ${totalMatches} total matches (${titleMatches} in title, ${bodyMatches} in body) - ${keywordMatches} keyword matches, ${substringMatches} substring matches, ${regexMatches} regex matches`; reason = `Found ${totalMatches} total matches (${titleMatches} in title, ${bodyMatches} in body) - ${keywordMatches} keyword matches, ${substringMatches} substring matches, ${regexMatches} regex matches`;
} }
core.notice(`Final decision: ${shouldAddLabel ? 'ADD LABEL' : 'DO NOT ADD LABEL'}`); core.notice(`Final decision: ${shouldAddLabel ? 'ADD LABEL' : 'DO NOT ADD LABEL'}`);
core.notice(`Reason: ${reason || 'No matching terms found'}`); core.notice(`Reason: ${reason || 'No matching terms found'}`);
if (shouldAddLabel) { if (shouldAddLabel) {
const existingLabels = context.payload.issue.labels.map(l => l.name); const existingLabels = context.payload.issue.labels.map(l => l.name);
if (!existingLabels.includes(labelName)) { if (!existingLabels.includes(labelName)) {
@ -270,92 +296,14 @@ jobs:
core.notice(`Label "${labelName}" already present.`); core.notice(`Label "${labelName}" already present.`);
return false; return false;
} }
core.notice(`No matching terms found for label "${labelName}".`); core.notice(`No matching terms found for label "${labelName}".`);
return false; return false;
} }
// Process all configured labels // Process all configured labels
const labelsAddedResults = await Promise.all( const processLabels = Object.entries(labelConfig)
Object.entries(labelConfig).map(([labelName, config]) => .map(([labelName, config]) => processLabel(labelName, config));
processLabel(labelName, config).then(added => ({ labelName, added })) const labelsAdded = await Promise.all(processLabels);
) const numLabelsAdded = labelsAdded.reduce((x, y) => x + y, 0);
); core.notice(`Processing complete. ${numLabelsAdded} label(s) added.`);
const numLabelsAdded = labelsAddedResults.filter(r => r.added).length;
core.notice(`Processing complete. ${numLabelsAdded} label(s) added.`);
// Return which labels were added for the next step
const addedLabels = labelsAddedResults.filter(r => r.added).map(r => r.labelName);
core.setOutput('labels_added', JSON.stringify(addedLabels));
return addedLabels;
- name: CC users for labeled issues
if: steps.label-step.outputs.labels_added != '[]'
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
script: |
// Configuration: Map labels to GitHub users to CC
// You can add multiple users per label, and multiple label configurations
const ccConfig = {
rocm: {
users: ['hongxiayang', 'tjtanaa', 'vllmellm'], // Add more users as needed: ['user1', 'user2', 'user3']
message: 'CC {users} for ROCm-related issue' // {users} will be replaced with @mentions
},
// Add more label -> user mappings here
// Example:
// cuda: {
// users: ['user1', 'user2'],
// message: 'CC {users} for CUDA-related issue'
// },
// performance: {
// users: ['perfexpert'],
// message: 'CC {users} for performance issue'
// },
};
const labelsAdded = JSON.parse('${{ steps.label-step.outputs.labels_added }}');
core.notice(`Labels added: ${labelsAdded.join(', ')}`);
// Get existing comments to check for already mentioned users
const comments = await github.rest.issues.listComments({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
});
const issueBody = context.payload.issue.body || '';
const allExistingText = issueBody + '\n' + comments.data.map(c => c.body).join('\n');
// Process each label that was added
for (const label of labelsAdded) {
if (ccConfig[label]) {
const config = ccConfig[label];
const usersToMention = [];
// Check which users haven't been mentioned yet
for (const user of config.users) {
const mentionPattern = new RegExp(`@${user}\\b`, 'i');
if (!mentionPattern.test(allExistingText)) {
usersToMention.push(user);
} else {
core.notice(`@${user} already mentioned for label "${label}", skipping`);
}
}
// Post comment if there are users to mention
if (usersToMention.length > 0) {
const mentions = usersToMention.map(u => `@${u}`).join(' ');
const message = config.message.replace('{users}', mentions);
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
body: message
});
core.notice(`CC comment added for label "${label}": ${mentions}`);
} else {
core.notice(`All users for label "${label}" already mentioned, skipping comment`);
}
}
}

3
.gitignore vendored
View File

@ -94,9 +94,6 @@ ipython_config.py
# generated files # generated files
**/generated/** **/generated/**
# uv
uv.lock
# pyenv # pyenv
# For a library or package, you might want to ignore these files since the code is # For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in: # intended to run in multiple environments; otherwise, check them in:

View File

@ -4,6 +4,7 @@ MD013: false
MD024: MD024:
siblings_only: true siblings_only: true
MD033: false MD033: false
MD042: false
MD045: false MD045: false
MD046: false MD046: false
MD051: false MD051: false

View File

@ -7,18 +7,17 @@ default_stages:
exclude: 'vllm/third_party/.*' exclude: 'vllm/third_party/.*'
repos: repos:
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.0 rev: v0.13.3
hooks: hooks:
- id: ruff-check - id: ruff-check
args: [--output-format, github, --fix] args: [--output-format, github, --fix]
- id: ruff-format - id: ruff-format
- repo: https://github.com/crate-ci/typos - repo: https://github.com/crate-ci/typos
rev: v1.38.1 rev: v1.35.5
hooks: hooks:
- id: typos - id: typos
args: [--force-exclude]
- repo: https://github.com/pre-commit/mirrors-clang-format - repo: https://github.com/pre-commit/mirrors-clang-format
rev: v21.1.2 rev: v20.1.3
hooks: hooks:
- id: clang-format - id: clang-format
exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*' exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*'
@ -35,7 +34,7 @@ repos:
hooks: hooks:
- id: actionlint - id: actionlint
- repo: https://github.com/astral-sh/uv-pre-commit - repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.9.1 rev: 0.6.17
hooks: hooks:
- id: pip-compile - id: pip-compile
args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128, --python-platform, x86_64-manylinux_2_28] args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128, --python-platform, x86_64-manylinux_2_28]
@ -56,6 +55,11 @@ repos:
types_or: [python, pyi] types_or: [python, pyi]
require_serial: true require_serial: true
additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic] additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.9
entry: python tools/pre_commit/mypy.py 1 "3.9"
<<: *mypy_common
stages: [manual] # Only run in CI
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.10 name: Run mypy for Python 3.10
entry: python tools/pre_commit/mypy.py 1 "3.10" entry: python tools/pre_commit/mypy.py 1 "3.10"
@ -71,11 +75,6 @@ repos:
entry: python tools/pre_commit/mypy.py 1 "3.12" entry: python tools/pre_commit/mypy.py 1 "3.12"
<<: *mypy_common <<: *mypy_common
stages: [manual] # Only run in CI stages: [manual] # Only run in CI
- id: mypy-3.13 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.13
entry: python tools/pre_commit/mypy.py 1 "3.13"
<<: *mypy_common
stages: [manual] # Only run in CI
- id: shellcheck - id: shellcheck
name: Lint shell scripts name: Lint shell scripts
entry: tools/shellcheck.sh entry: tools/shellcheck.sh

View File

@ -34,7 +34,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
# Supported python versions. These versions will be searched in order, the # Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py. # first match will be selected. These should be kept in sync with setup.py.
# #
set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13") set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12" "3.13")
# Supported AMD GPU architectures. # Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151") set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151")
@ -269,8 +269,8 @@ set(VLLM_EXT_SRC
"csrc/sampler.cu" "csrc/sampler.cu"
"csrc/cuda_view.cu" "csrc/cuda_view.cu"
"csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/w8a8/int8/scaled_quant.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/w8a8/fp8/common.cu" "csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/activation_kernels.cu" "csrc/quantization/activation_kernels.cu"
@ -314,13 +314,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC list(APPEND VLLM_EXT_SRC
"csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu"
"csrc/permute_cols.cu" "csrc/permute_cols.cu"
"csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp" "csrc/cutlass_extensions/common.cpp"
"csrc/quantization/w8a8/fp8/per_token_group_quant.cu" "csrc/quantization/fp8/per_token_group_quant.cu")
"csrc/quantization/w8a8/int8/per_token_group_quant.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}" SRCS "${VLLM_EXT_SRC}"
@ -424,11 +423,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
set(SRCS set(SRCS
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu") "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${SRCS}" SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}") CUDA_ARCHS "${SCALED_MM_ARCHS}")
@ -459,9 +458,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif() endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS set(SRCS
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu"
) )
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${SRCS}" SRCS "${SRCS}"
@ -493,9 +492,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif() endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS set(SRCS
"csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
"csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu"
) )
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${SRCS}" SRCS "${SRCS}"
@ -526,7 +525,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# subtract out the archs that are already built for 3x # subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
if (SCALED_MM_2X_ARCHS) if (SCALED_MM_2X_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu") set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${SRCS}" SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}") CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
@ -649,7 +648,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# if it's possible to compile MoE kernels that use its output. # if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu") set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${SRCS}" SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}") CUDA_ARCHS "${SCALED_MM_ARCHS}")
@ -673,7 +672,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}") cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
endif() endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu") set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${SRCS}" SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}") CUDA_ARCHS "${SCALED_MM_ARCHS}")
@ -698,7 +697,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
endif() endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu") set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${SRCS}" SRCS "${SRCS}"
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}") CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
@ -721,7 +720,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
endif() endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu") set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${SRCS}" SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}") CUDA_ARCHS "${SCALED_MM_ARCHS}")
@ -883,7 +882,6 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
set(VLLM_MOE_EXT_SRC set(VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp" "csrc/moe/torch_bindings.cpp"
"csrc/moe/moe_align_sum_kernels.cu" "csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/moe_lora_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu") "csrc/moe/topk_softmax_kernels.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
@ -1008,7 +1006,6 @@ endif()
# For CUDA we also build and ship some external projects. # For CUDA we also build and ship some external projects.
if (VLLM_GPU_LANG STREQUAL "CUDA") if (VLLM_GPU_LANG STREQUAL "CUDA")
include(cmake/external_projects/flashmla.cmake) include(cmake/external_projects/flashmla.cmake)
include(cmake/external_projects/qutlass.cmake)
# vllm-flash-attn should be last as it overwrites some CMake functions # vllm-flash-attn should be last as it overwrites some CMake functions
include(cmake/external_projects/vllm_flash_attn.cmake) include(cmake/external_projects/vllm_flash_attn.cmake)

View File

@ -149,7 +149,6 @@ Compute Resources:
- Trainy - Trainy
- UC Berkeley - UC Berkeley
- UC San Diego - UC San Diego
- Volcengine
Slack Sponsor: Anyscale Slack Sponsor: Anyscale

View File

@ -74,7 +74,7 @@ start_server() {
local vllm_log=$4 local vllm_log=$4
local profile_dir=$5 local profile_dir=$5
pkill -if "vllm serve" || true pkill -if vllm
# Define the common arguments as a bash array. # Define the common arguments as a bash array.
# Each argument and its value are separate elements. # Each argument and its value are separate elements.
@ -96,11 +96,11 @@ start_server() {
# This correctly passes each element as a separate argument. # This correctly passes each element as a separate argument.
if [[ -n "$profile_dir" ]]; then if [[ -n "$profile_dir" ]]; then
# Start server with profiling enabled # Start server with profiling enabled
VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \ VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \
vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 & vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
else else
# Start server without profiling # Start server without profiling
VLLM_SERVER_DEV_MODE=1 \ VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 \
vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 & vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
fi fi
local server_pid=$! local server_pid=$!
@ -139,7 +139,7 @@ run_benchmark() {
echo "vllm_log: $vllm_log" echo "vllm_log: $vllm_log"
echo echo
rm -f $vllm_log rm -f $vllm_log
pkill -if "vllm serve" || true pkill -if vllm
echo "starting server..." echo "starting server..."
# Call start_server without a profile_dir to avoid profiling overhead # Call start_server without a profile_dir to avoid profiling overhead
@ -232,7 +232,7 @@ run_benchmark() {
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput"
pkill -if "vllm serve" || true pkill -if vllm
sleep 10 sleep 10
echo "====================" echo "===================="
return 0 return 0
@ -308,6 +308,6 @@ if (( $(echo "$best_throughput > 0" | bc -l) )); then
else else
echo "No configuration met the latency requirements. Skipping final profiling run." echo "No configuration met the latency requirements. Skipping final profiling run."
fi fi
pkill -if "vllm serve" || true pkill -if vllm
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH"
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" >> "$RESULT" echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" >> "$RESULT"

View File

@ -8,6 +8,7 @@ import sys
import time import time
import traceback import traceback
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, Union
import aiohttp import aiohttp
import huggingface_hub.constants import huggingface_hub.constants
@ -27,13 +28,13 @@ class RequestFuncInput:
prompt_len: int prompt_len: int
output_len: int output_len: int
model: str model: str
model_name: str | None = None model_name: Optional[str] = None
logprobs: int | None = None logprobs: Optional[int] = None
extra_body: dict | None = None extra_body: Optional[dict] = None
multi_modal_content: dict | list[dict] | None = None multi_modal_content: Optional[dict | list[dict]] = None
ignore_eos: bool = False ignore_eos: bool = False
language: str | None = None language: Optional[str] = None
request_id: str | None = None request_id: Optional[str] = None
@dataclass @dataclass
@ -51,7 +52,7 @@ class RequestFuncOutput:
async def async_request_tgi( async def async_request_tgi(
request_func_input: RequestFuncInput, request_func_input: RequestFuncInput,
pbar: tqdm | None = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith("generate_stream") assert api_url.endswith("generate_stream")
@ -132,7 +133,7 @@ async def async_request_tgi(
async def async_request_trt_llm( async def async_request_trt_llm(
request_func_input: RequestFuncInput, request_func_input: RequestFuncInput,
pbar: tqdm | None = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith("generate_stream") assert api_url.endswith("generate_stream")
@ -203,7 +204,7 @@ async def async_request_trt_llm(
async def async_request_deepspeed_mii( async def async_request_deepspeed_mii(
request_func_input: RequestFuncInput, request_func_input: RequestFuncInput,
pbar: tqdm | None = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith(("completions", "profile")), ( assert api_url.endswith(("completions", "profile")), (
@ -266,7 +267,7 @@ async def async_request_deepspeed_mii(
async def async_request_openai_completions( async def async_request_openai_completions(
request_func_input: RequestFuncInput, request_func_input: RequestFuncInput,
pbar: tqdm | None = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith(("completions", "profile")), ( assert api_url.endswith(("completions", "profile")), (
@ -366,7 +367,7 @@ async def async_request_openai_completions(
async def async_request_openai_chat_completions( async def async_request_openai_chat_completions(
request_func_input: RequestFuncInput, request_func_input: RequestFuncInput,
pbar: tqdm | None = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith(("chat/completions", "profile")), ( assert api_url.endswith(("chat/completions", "profile")), (
@ -475,7 +476,7 @@ async def async_request_openai_chat_completions(
async def async_request_openai_audio( async def async_request_openai_audio(
request_func_input: RequestFuncInput, request_func_input: RequestFuncInput,
pbar: tqdm | None = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
# Lazy import without PlaceholderModule to avoid vllm dep. # Lazy import without PlaceholderModule to avoid vllm dep.
import soundfile import soundfile
@ -609,7 +610,7 @@ def get_tokenizer(
tokenizer_mode: str = "auto", tokenizer_mode: str = "auto",
trust_remote_code: bool = False, trust_remote_code: bool = False,
**kwargs, **kwargs,
) -> PreTrainedTokenizer | PreTrainedTokenizerFast: ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
if pretrained_model_name_or_path is not None and not os.path.exists( if pretrained_model_name_or_path is not None and not os.path.exists(
pretrained_model_name_or_path pretrained_model_name_or_path
): ):

View File

@ -32,6 +32,7 @@ import dataclasses
import json import json
import random import random
import time import time
from typing import Optional
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
@ -79,7 +80,7 @@ def sample_requests_from_dataset(
num_requests: int, num_requests: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
input_length_range: tuple[int, int], input_length_range: tuple[int, int],
fixed_output_len: int | None, fixed_output_len: Optional[int],
) -> list[Request]: ) -> list[Request]:
if fixed_output_len is not None and fixed_output_len < 4: if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small") raise ValueError("output_len too small")
@ -127,7 +128,7 @@ def sample_requests_from_random(
num_requests: int, num_requests: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
input_length_range: tuple[int, int], input_length_range: tuple[int, int],
fixed_output_len: int | None, fixed_output_len: Optional[int],
prefix_len: int, prefix_len: int,
) -> list[Request]: ) -> list[Request]:
requests = [] requests = []

View File

@ -7,6 +7,7 @@ import dataclasses
import json import json
import random import random
import time import time
from typing import Optional
from transformers import AutoTokenizer, PreTrainedTokenizerBase from transformers import AutoTokenizer, PreTrainedTokenizerBase
@ -23,7 +24,7 @@ def sample_requests(
dataset_path: str, dataset_path: str,
num_requests: int, num_requests: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
fixed_output_len: int | None, fixed_output_len: Optional[int],
) -> list[tuple[str, int, int, int]]: ) -> list[tuple[str, int, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4: if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small") raise ValueError("output_len too small")

View File

@ -31,8 +31,8 @@ import time
import uuid import uuid
import warnings import warnings
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from contextlib import nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
import datasets import datasets
import numpy as np import numpy as np
@ -316,7 +316,7 @@ def calculate_metrics(
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
selected_percentile_metrics: list[str], selected_percentile_metrics: list[str],
selected_percentiles: list[float], selected_percentiles: list[float],
goodput_config_dict: dict[str, float] | None = None, goodput_config_dict: Optional[dict[str, float]] = None,
) -> tuple[BenchmarkMetrics, list[int]]: ) -> tuple[BenchmarkMetrics, list[int]]:
actual_output_lens: list[int] = [] actual_output_lens: list[int] = []
total_input = 0 total_input = 0
@ -436,9 +436,9 @@ async def benchmark(
selected_percentile_metrics: list[str], selected_percentile_metrics: list[str],
selected_percentiles: list[str], selected_percentiles: list[str],
ignore_eos: bool, ignore_eos: bool,
max_concurrency: int | None, max_concurrency: Optional[int],
structured_output_ratio: float, structured_output_ratio: float,
goodput_config_dict: dict[str, float] | None = None, goodput_config_dict: Optional[dict[str, float]] = None,
): ):
if backend in ASYNC_REQUEST_FUNCS: if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend] request_func = ASYNC_REQUEST_FUNCS[backend]
@ -502,9 +502,15 @@ async def benchmark(
pbar = None if disable_tqdm else tqdm(total=len(input_requests)) pbar = None if disable_tqdm else tqdm(total=len(input_requests))
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else nullcontext() # This can be used once the minimum Python version is 3.10 or higher,
# and it will simplify the code in limited_request_func.
# semaphore = (asyncio.Semaphore(max_concurrency)
# if max_concurrency else contextlib.nullcontext())
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
async def limited_request_func(request_func_input, pbar): async def limited_request_func(request_func_input, pbar):
if semaphore is None:
return await request_func(request_func_input=request_func_input, pbar=pbar)
async with semaphore: async with semaphore:
return await request_func(request_func_input=request_func_input, pbar=pbar) return await request_func(request_func_input=request_func_input, pbar=pbar)

View File

@ -6,7 +6,7 @@ import math
import os import os
import time import time
from types import TracebackType from types import TracebackType
from typing import Any from typing import Any, Optional, Union
def convert_to_pytorch_benchmark_format( def convert_to_pytorch_benchmark_format(
@ -92,7 +92,7 @@ class TimeCollector:
def __init__(self, scale: int) -> None: def __init__(self, scale: int) -> None:
self.cnt: int = 0 self.cnt: int = 0
self._sum: int = 0 self._sum: int = 0
self._max: int | None = None self._max: Optional[int] = None
self.scale = scale self.scale = scale
self.start_time: int = time.monotonic_ns() self.start_time: int = time.monotonic_ns()
@ -104,13 +104,13 @@ class TimeCollector:
else: else:
self._max = max(self._max, v) self._max = max(self._max, v)
def avg(self) -> float | str: def avg(self) -> Union[float, str]:
return self._sum * 1.0 / self.cnt / self.scale if self.cnt > 0 else "N/A" return self._sum * 1.0 / self.cnt / self.scale if self.cnt > 0 else "N/A"
def max(self) -> float | str: def max(self) -> Union[float, str]:
return self._max / self.scale if self._max else "N/A" return self._max / self.scale if self._max else "N/A"
def dump_avg_max(self) -> list[float | str]: def dump_avg_max(self) -> list[Union[float, str]]:
return [self.avg(), self.max()] return [self.avg(), self.max()]
def __enter__(self) -> None: def __enter__(self) -> None:
@ -118,8 +118,8 @@ class TimeCollector:
def __exit__( def __exit__(
self, self,
exc_type: type[BaseException] | None, exc_type: Optional[type[BaseException]],
exc_value: BaseException | None, exc_value: Optional[BaseException],
exc_traceback: TracebackType | None, exc_traceback: Optional[TracebackType],
) -> None: ) -> None:
self.collect(time.monotonic_ns() - self.start_time) self.collect(time.monotonic_ns() - self.start_time)

View File

@ -6,7 +6,8 @@ import copy
import itertools import itertools
import pickle as pkl import pickle as pkl
import time import time
from collections.abc import Callable, Iterable from collections.abc import Iterable
from typing import Callable
import torch import torch
import torch.utils.benchmark as TBenchmark import torch.utils.benchmark as TBenchmark

View File

@ -6,7 +6,8 @@ import copy
import itertools import itertools
import pickle as pkl import pickle as pkl
import time import time
from collections.abc import Callable, Iterable from collections.abc import Iterable
from typing import Callable, Optional
import torch import torch
import torch.utils.benchmark as TBenchmark import torch.utils.benchmark as TBenchmark
@ -52,7 +53,7 @@ def bench_int8(
n: int, n: int,
label: str, label: str,
sub_label: str, sub_label: str,
bench_kernels: list[str] | None = None, bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]: ) -> Iterable[TMeasurement]:
"""Benchmark INT8-based kernels.""" """Benchmark INT8-based kernels."""
assert dtype == torch.int8 assert dtype == torch.int8
@ -107,7 +108,7 @@ def bench_fp8(
n: int, n: int,
label: str, label: str,
sub_label: str, sub_label: str,
bench_kernels: list[str] | None = None, bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]: ) -> Iterable[TMeasurement]:
"""Benchmark FP8-based kernels.""" """Benchmark FP8-based kernels."""
assert dtype == torch.float8_e4m3fn assert dtype == torch.float8_e4m3fn
@ -182,7 +183,7 @@ def bench(
n: int, n: int,
label: str, label: str,
sub_label: str, sub_label: str,
bench_kernels: list[str] | None = None, bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]: ) -> Iterable[TMeasurement]:
if dtype == torch.int8: if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels) return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
@ -200,7 +201,7 @@ def print_timers(timers: Iterable[TMeasurement]):
def run( def run(
dtype: torch.dtype, dtype: torch.dtype,
MKNs: Iterable[tuple[int, int, int]], MKNs: Iterable[tuple[int, int, int]],
bench_kernels: list[str] | None = None, bench_kernels: Optional[list[str]] = None,
) -> Iterable[TMeasurement]: ) -> Iterable[TMeasurement]:
results = [] results = []
for m, k, n in MKNs: for m, k, n in MKNs:

View File

@ -3,9 +3,10 @@
import pickle as pkl import pickle as pkl
import time import time
from collections.abc import Callable, Iterable from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from itertools import product from itertools import product
from typing import Callable, Optional
import torch import torch
import torch.utils.benchmark as TBenchmark import torch.utils.benchmark as TBenchmark
@ -50,7 +51,7 @@ def get_bench_params() -> list[bench_params_t]:
def unfused_int8_impl( def unfused_int8_impl(
rms_norm_layer: RMSNorm, rms_norm_layer: RMSNorm,
x: torch.Tensor, x: torch.Tensor,
residual: torch.Tensor | None, residual: Optional[torch.Tensor],
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
): ):
# Norm # Norm
@ -67,7 +68,7 @@ def unfused_int8_impl(
def unfused_fp8_impl( def unfused_fp8_impl(
rms_norm_layer: RMSNorm, rms_norm_layer: RMSNorm,
x: torch.Tensor, x: torch.Tensor,
residual: torch.Tensor | None, residual: Optional[torch.Tensor],
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
): ):
# Norm # Norm
@ -84,7 +85,7 @@ def unfused_fp8_impl(
def fused_impl( def fused_impl(
rms_norm_layer: RMSNorm, # this stores the weights rms_norm_layer: RMSNorm, # this stores the weights
x: torch.Tensor, x: torch.Tensor,
residual: torch.Tensor | None, residual: Optional[torch.Tensor],
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
): ):
out, _ = ops.rms_norm_dynamic_per_token_quant( out, _ = ops.rms_norm_dynamic_per_token_quant(

View File

@ -1,191 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import copy
import itertools
import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from weight_shapes import WEIGHT_SHAPES
from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.triton_utils import triton
PROVIDER_CFGS = {
"torch-bf16": dict(enabled=True),
"mxfp4": dict(no_a_quant=False, enabled=True),
"mxfp4-noquant": dict(no_a_quant=True, enabled=True),
}
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
return (
deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
* group_size**-0.5
)
def _quant_weight_mxfp4(
b: torch.Tensor, forward_hadamard_matrix: torch.Tensor, device: str
):
weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeMx(
b, forward_hadamard_matrix, method="abs_max"
)
weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton")
return weight_hf_e2m1, weight_hf_scale_block
def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device):
weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4(
b, forward_hadamard_matrix, device
)
alpha = torch.tensor([1.0], device="cuda")
if cfg["no_a_quant"]:
# Pre-quantize activation
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
a, forward_hadamard_matrix, method="abs_max"
)
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")
def run():
return matmul_mxf4_bf16_tn(
input_hf_e2m1,
weight_hf_e2m1,
input_hf_scale_block,
weight_hf_scale_block,
alpha,
)
return run
# Quantize activation on-the-fly
def run():
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
a, forward_hadamard_matrix, method="abs_max"
)
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")
return matmul_mxf4_bf16_tn(
input_hf_e2m1,
weight_hf_e2m1,
input_hf_scale_block,
weight_hf_scale_block,
alpha,
)
return run
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[
1,
4,
8,
16,
32,
64,
128,
256,
512,
1024,
2048,
4096,
8192,
16384,
24576,
32768,
],
x_log=False,
line_arg="provider",
line_vals=_enabled,
line_names=_enabled,
ylabel="TFLOP/s (larger is better)",
plot_name="BF16 vs MXFP4 GEMMs",
args={},
)
)
def benchmark(batch_size, provider, N, K, had_size):
M = batch_size
device = "cuda"
dtype = torch.bfloat16
a = torch.randn((M, K), device=device, dtype=dtype)
b = torch.randn((N, K), device=device, dtype=dtype)
forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)
quantiles = [0.5, 0.2, 0.8]
if provider == "torch-bf16":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
)
else:
cfg = PROVIDER_CFGS[provider]
run_quant = build_mxfp4_runner(
cfg, a, b, forward_hadamard_matrix, dtype, device
)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: run_quant(), rep=200, quantiles=quantiles
)
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
def prepare_shapes(args):
out = []
for model, tp_size in itertools.product(args.models, args.tp_sizes):
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
KN[tp_dim] //= tp_size
KN.append(model)
out.append(KN)
return out
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
nargs="+",
type=str,
default=["meta-llama/Llama-3.3-70B-Instruct"],
choices=list(WEIGHT_SHAPES.keys()),
)
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
args = parser.parse_args()
for K, N, model in prepare_shapes(args):
for had_size in [32, 64, 128]:
print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs MXFP4 GEMMs TFLOP/s:")
benchmark.run(
print_data=True,
show_plots=True,
save_path=f"bench_mxfp4_res_n{N}_k{K}",
N=N,
K=K,
had_size=had_size,
)
print("Benchmark finished!")

View File

@ -1,207 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import copy
import itertools
import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops # use existing nvfp4 gemm in vllm
from vllm._custom_ops import fusedQuantizeNv
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.triton_utils import triton
PROVIDER_CFGS = {
"torch-bf16": dict(enabled=True),
"nvfp4": dict(no_a_quant=False, enabled=True),
"nvfp4-noquant": dict(no_a_quant=True, enabled=True),
}
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
return (
deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
* group_size**-0.5
)
def _quant_weight_nvfp4(
b: torch.Tensor,
forward_hadamard_matrix: torch.Tensor,
global_scale: torch.Tensor,
device: str,
M: int,
N: int,
K: int,
):
weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeNv(
b, forward_hadamard_matrix, global_scale
)
weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton").view(
-1, K // 16
)
return weight_hf_e2m1, weight_hf_scale_block
def build_nvfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K):
alpha = torch.tensor([1.0], device="cuda")
global_scale = torch.tensor([1.0], device="cuda")
weight_hf_e2m1, weight_hf_scale_block = _quant_weight_nvfp4(
b, forward_hadamard_matrix, global_scale, device, M, N, K
)
if cfg["no_a_quant"]:
# Pre-quantize activation
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(
a, forward_hadamard_matrix, global_scale
)
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view(
-1, K // 16
)
def run():
return ops.cutlass_scaled_fp4_mm(
input_hf_e2m1,
weight_hf_e2m1,
input_hf_scale_block,
weight_hf_scale_block,
alpha,
torch.bfloat16,
)
return run
# Quantize activation on-the-fly
def run():
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(
a, forward_hadamard_matrix, global_scale
)
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view(
-1, K // 16
)
return ops.cutlass_scaled_fp4_mm(
input_hf_e2m1,
weight_hf_e2m1,
input_hf_scale_block,
weight_hf_scale_block,
alpha,
torch.bfloat16,
)
return run
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[
1,
4,
8,
16,
32,
64,
128,
256,
512,
1024,
2048,
4096,
8192,
16384,
24576,
32768,
],
x_log=False,
line_arg="provider",
line_vals=_enabled,
line_names=_enabled,
ylabel="TFLOP/s (larger is better)",
plot_name="BF16 vs NVFP4 GEMMs",
args={},
)
)
def benchmark(batch_size, provider, N, K, had_size):
M = batch_size
device = "cuda"
dtype = torch.bfloat16
a = torch.randn((M, K), device=device, dtype=dtype)
b = torch.randn((N, K), device=device, dtype=dtype)
forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)
quantiles = [0.5, 0.2, 0.8]
if provider == "torch-bf16":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
)
else:
cfg = PROVIDER_CFGS[provider]
run_quant = build_nvfp4_runner(
cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K
)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: run_quant(), rep=200, quantiles=quantiles
)
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
def prepare_shapes(args):
out = []
for model, tp_size in itertools.product(args.models, args.tp_sizes):
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
KN[tp_dim] //= tp_size
KN.append(model)
out.append(KN)
return out
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
nargs="+",
type=str,
default=["meta-llama/Llama-3.3-70B-Instruct"],
choices=list(WEIGHT_SHAPES.keys()),
)
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
args = parser.parse_args()
for K, N, model in prepare_shapes(args):
for had_size in [16, 32, 64, 128]:
print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs NVFP4 GEMMs TFLOP/s:")
benchmark.run(
print_data=True,
show_plots=True,
save_path=f"bench_nvfp4_res_n{N}_k{K}",
N=N,
K=K,
had_size=had_size,
)
print("Benchmark finished!")

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools import itertools
from collections.abc import Callable from typing import Callable
from unittest.mock import patch from unittest.mock import patch
import pandas as pd import pandas as pd
@ -10,8 +10,7 @@ import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
def with_triton_mode(fn): def with_triton_mode(fn):

View File

@ -10,8 +10,7 @@ import vllm.model_executor.layers.activation # noqa F401
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
batch_size_range = [1, 16, 32, 64, 128] batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]

View File

@ -22,8 +22,8 @@ Example:
import json import json
import os import os
import time import time
from collections.abc import Callable
from contextlib import nullcontext from contextlib import nullcontext
from typing import Callable, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -264,12 +264,12 @@ class CommunicatorBenchmark:
def benchmark_allreduce_single( def benchmark_allreduce_single(
self, self,
sequence_length: int, sequence_length: int,
allreduce_fn: Callable[[torch.Tensor], torch.Tensor | None], allreduce_fn: Callable[[torch.Tensor], Optional[torch.Tensor]],
should_use_fn: Callable[[torch.Tensor], bool], should_use_fn: Callable[[torch.Tensor], bool],
context, context,
num_warmup: int, num_warmup: int,
num_trials: int, num_trials: int,
) -> float | None: ) -> Optional[float]:
"""Benchmark method with CUDA graph optimization.""" """Benchmark method with CUDA graph optimization."""
try: try:
# Create test tensor (2D: sequence_length x hidden_size) # Create test tensor (2D: sequence_length x hidden_size)

View File

@ -7,8 +7,7 @@ import torch
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
@torch.inference_mode() @torch.inference_mode()

View File

@ -6,12 +6,11 @@ import copy
import json import json
import pickle import pickle
import time import time
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from itertools import product from itertools import product
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Callable, Optional
import torch import torch
import torch.utils.benchmark as TBenchmark import torch.utils.benchmark as TBenchmark
@ -159,7 +158,7 @@ def ref_group_gemm(
seq_lens_cpu: torch.Tensor, seq_lens_cpu: torch.Tensor,
prompt_lora_mapping_cpu: torch.Tensor, prompt_lora_mapping_cpu: torch.Tensor,
scaling: float, scaling: float,
add_inputs: bool | None, add_inputs: Optional[bool],
): ):
""" """
Torch group gemm reference implementation to test correctness of Torch group gemm reference implementation to test correctness of
@ -317,8 +316,8 @@ class BenchmarkContext:
lora_rank: int lora_rank: int
sort_by_lora_id: bool sort_by_lora_id: bool
dtype: torch.dtype dtype: torch.dtype
seq_length: int | None = None seq_length: Optional[int] = None
num_slices: int | None = None # num_slices for slice based ops num_slices: Optional[int] = None # num_slices for slice based ops
def with_seq_length(self, seq_length: int) -> "BenchmarkContext": def with_seq_length(self, seq_length: int) -> "BenchmarkContext":
ctx = copy.copy(self) ctx = copy.copy(self)
@ -562,7 +561,7 @@ class BenchmarkTensors:
} }
def bench_fn_kwargs( def bench_fn_kwargs(
self, op_type: OpType, add_inputs: bool | None = None self, op_type: OpType, add_inputs: Optional[bool] = None
) -> dict[str, Any]: ) -> dict[str, Any]:
if op_type.is_shrink_fn(): if op_type.is_shrink_fn():
assert add_inputs is None assert add_inputs is None
@ -576,7 +575,7 @@ class BenchmarkTensors:
raise ValueError(f"Unrecognized optype {self}") raise ValueError(f"Unrecognized optype {self}")
def test_correctness( def test_correctness(
self, op_type: OpType, expand_fn_add_inputs: bool | None self, op_type: OpType, expand_fn_add_inputs: Optional[bool]
) -> bool: ) -> bool:
""" """
Test correctness of op_type implementation against a grouped gemm Test correctness of op_type implementation against a grouped gemm
@ -612,8 +611,8 @@ def bench_optype(
ctx: BenchmarkContext, ctx: BenchmarkContext,
arg_pool_size: int, arg_pool_size: int,
op_type: OpType, op_type: OpType,
cuda_graph_nops: int | None = None, cuda_graph_nops: Optional[int] = None,
expand_fn_add_inputs: bool | None = None, expand_fn_add_inputs: Optional[bool] = None,
test_correctness: bool = False, test_correctness: bool = False,
) -> TMeasurement: ) -> TMeasurement:
assert arg_pool_size >= 1 assert arg_pool_size >= 1
@ -680,7 +679,7 @@ def bench_torch_mm(
ctx: BenchmarkContext, ctx: BenchmarkContext,
arg_pool_size: int, arg_pool_size: int,
op_type: OpType, op_type: OpType,
cuda_graph_nops: int | None = None, cuda_graph_nops: Optional[int] = None,
) -> TMeasurement: ) -> TMeasurement:
""" """
Benchmark basic torch.mm as a roofline. Benchmark basic torch.mm as a roofline.
@ -745,7 +744,7 @@ def use_cuda_graph_recommendation() -> str:
""" """
def print_timers(timers: list[TMeasurement], args: argparse.Namespace | None = None): def print_timers(timers: list[TMeasurement], args: Optional[argparse.Namespace] = None):
compare = TBenchmark.Compare(timers) compare = TBenchmark.Compare(timers)
compare.print() compare.print()

View File

@ -8,9 +8,10 @@ import math
import os import os
import pickle as pkl import pickle as pkl
import time import time
from collections.abc import Callable, Iterable from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from itertools import product from itertools import product
from typing import Callable, Optional
import pandas as pd import pandas as pd
import torch import torch
@ -62,23 +63,23 @@ class BenchmarkTensors:
a: torch.Tensor a: torch.Tensor
w_q: torch.Tensor w_q: torch.Tensor
group_size: int | None group_size: Optional[int]
wtype: ScalarType wtype: ScalarType
w_g_s: torch.Tensor w_g_s: torch.Tensor
w_g_zp: torch.Tensor | None w_g_zp: Optional[torch.Tensor]
w_ch_s: torch.Tensor | None w_ch_s: Optional[torch.Tensor]
w_tok_s: torch.Tensor | None w_tok_s: Optional[torch.Tensor]
@dataclass @dataclass
class TypeConfig: class TypeConfig:
act_type: torch.dtype act_type: torch.dtype
weight_type: ScalarType weight_type: ScalarType
output_type: torch.dtype | None output_type: Optional[torch.dtype]
group_scale_type: torch.dtype | None group_scale_type: Optional[torch.dtype]
group_zero_type: torch.dtype | None group_zero_type: Optional[torch.dtype]
channel_scale_type: torch.dtype | None channel_scale_type: Optional[torch.dtype]
token_scale_type: torch.dtype | None token_scale_type: Optional[torch.dtype]
def rand_data(shape, dtype=torch.float16, scale=1): def rand_data(shape, dtype=torch.float16, scale=1):
@ -92,8 +93,8 @@ def quantize_and_pack(
atype: torch.dtype, atype: torch.dtype,
w: torch.Tensor, w: torch.Tensor,
wtype: ScalarType, wtype: ScalarType,
stype: torch.dtype | None, stype: Optional[torch.dtype],
group_size: int | None, group_size: Optional[int],
zero_points: bool = False, zero_points: bool = False,
): ):
assert wtype.is_integer(), "TODO: support floating point weights" assert wtype.is_integer(), "TODO: support floating point weights"
@ -112,7 +113,7 @@ def quantize_and_pack(
def create_bench_tensors( def create_bench_tensors(
shape: tuple[int, int, int], types: TypeConfig, group_size: int | None shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int]
) -> list[BenchmarkTensors]: ) -> list[BenchmarkTensors]:
m, n, k = shape m, n, k = shape
@ -330,8 +331,8 @@ def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable])
return res return res
_SWEEP_SCHEDULES_RESULTS: pd.DataFrame | None = None _SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
_SWEEP_SCHEDULES_RESULTS_CSV: str | None = None _SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
def bench( def bench(

View File

@ -579,12 +579,10 @@ def main(args: argparse.Namespace):
E = config.ffn_config.moe_num_experts E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size intermediate_size = config.ffn_config.ffn_hidden_size
hidden_size = config.hidden_size
elif config.architectures[0] == "JambaForCausalLM": elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts E = config.num_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] in ( elif config.architectures[0] in (
"DeepseekV2ForCausalLM", "DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM", "DeepseekV3ForCausalLM",
@ -594,7 +592,6 @@ def main(args: argparse.Namespace):
E = config.n_routed_experts E = config.n_routed_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] in ( elif config.architectures[0] in (
"Qwen2MoeForCausalLM", "Qwen2MoeForCausalLM",
"Qwen3MoeForCausalLM", "Qwen3MoeForCausalLM",
@ -603,18 +600,10 @@ def main(args: argparse.Namespace):
E = config.num_experts E = config.num_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
hidden_size = config.hidden_size
elif config.architectures[0] == "Qwen3VLMoeForConditionalGeneration":
text_config = config.get_text_config()
E = text_config.num_experts
topk = text_config.num_experts_per_tok
intermediate_size = text_config.moe_intermediate_size
hidden_size = text_config.hidden_size
elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"): elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
E = config.num_experts E = config.num_experts
topk = config.moe_topk[0] topk = config.moe_topk[0]
intermediate_size = config.moe_intermediate_size[0] intermediate_size = config.moe_intermediate_size[0]
hidden_size = config.hidden_size
else: else:
# Support for llama4 # Support for llama4
config = config.get_text_config() config = config.get_text_config()
@ -622,7 +611,6 @@ def main(args: argparse.Namespace):
E = config.num_local_experts E = config.num_local_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
hidden_size = config.hidden_size
enable_ep = bool(args.enable_expert_parallel) enable_ep = bool(args.enable_expert_parallel)
if enable_ep: if enable_ep:
ensure_divisibility(E, args.tp_size, "Number of experts") ensure_divisibility(E, args.tp_size, "Number of experts")
@ -631,7 +619,8 @@ def main(args: argparse.Namespace):
else: else:
ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size") ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
dtype = torch.float16 if current_platform.is_rocm() else config.dtype hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
block_quant_shape = get_weight_block_size_safety(config) block_quant_shape = get_weight_block_size_safety(config)

View File

@ -344,7 +344,7 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
hidden_size = config.hidden_size hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.dtype dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
use_customized_permute = args.use_customized_permute use_customized_permute = args.use_customized_permute

View File

@ -3,15 +3,16 @@
import random import random
import time import time
from typing import Optional
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils import (
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE, STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random, create_kv_caches_with_random,
) )
@ -36,7 +37,7 @@ def main(
seed: int, seed: int,
do_profile: bool, do_profile: bool,
device: str = "cuda", device: str = "cuda",
kv_cache_dtype: str | None = None, kv_cache_dtype: Optional[str] = None,
) -> None: ) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)

View File

@ -3,8 +3,8 @@
import argparse import argparse
import math import math
from collections.abc import Callable
from contextlib import contextmanager from contextlib import contextmanager
from typing import Callable
from unittest.mock import patch from unittest.mock import patch
import torch import torch

View File

@ -0,0 +1,155 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import torch
from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton
def polynorm_naive(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
def norm(x, eps: float):
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
x = x.float()
return (
(
weight[0] * norm(x**3, eps)
+ weight[1] * norm(x**2, eps)
+ weight[2] * norm(x, eps)
+ bias
)
.to(weight.dtype)
.view(orig_shape)
)
def polynorm_vllm(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
out = torch.empty_like(x)
vllm_ops.poly_norm(out, x, weight, bias, eps)
output = out
output = output.view(orig_shape)
return output
def calculate_diff(batch_size, seq_len, hidden_dim):
dtype = torch.bfloat16
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
weight = torch.ones(3, dtype=dtype, device="cuda")
bias = torch.ones(1, dtype=dtype, device="cuda")
output_naive = polynorm_naive(x, weight, bias)
output_vllm = polynorm_vllm(x, weight, bias)
if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
dim_range = [2048, 4096]
configs = list(itertools.product(dim_range, batch_size_range, seq_length_range))
def get_benchmark():
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["dim", "batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["naive", "vllm"],
line_names=["Naive", "vLLM"],
styles=[("blue", "-"), ("red", "-")],
ylabel="us",
plot_name="polynorm-perf",
args={},
)
)
def benchmark(dim, batch_size, seq_len, provider):
dtype = torch.bfloat16
hidden_dim = dim * 4
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
weight = torch.ones(3, dtype=dtype, device="cuda")
bias = torch.ones(1, dtype=dtype, device="cuda")
quantiles = [0.5, 0.2, 0.8]
if provider == "naive":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: polynorm_naive(x, weight, bias),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: polynorm_vllm(x, weight, bias),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--batch-size",
type=int,
default=4,
help="Batch size",
)
parser.add_argument(
"--seq-len",
type=int,
default=128,
help="Sequence length",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=8192,
help="Intermediate size of MLP",
)
parser.add_argument(
"--save-path",
type=str,
default="./configs/polnorm/",
help="Path to save polnorm benchmark results",
)
args = parser.parse_args()
# Run correctness test
calculate_diff(
batch_size=args.batch_size,
seq_len=args.seq_len,
hidden_dim=args.hidden_dim,
)
benchmark = get_benchmark()
# Run performance benchmark
benchmark.run(print_data=True, save_path=args.save_path)

View File

@ -7,8 +7,7 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
@torch.inference_mode() @torch.inference_mode()

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import random import random
import time import time
@ -9,9 +11,9 @@ from tabulate import tabulate
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils import (
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE, STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random, create_kv_caches_with_random,
) )

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import random import random
import time import time
@ -12,9 +14,9 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils import (
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE, STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random_flash, create_kv_caches_with_random_flash,
) )

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools import itertools
from typing import Optional, Union
import torch import torch
from flashinfer.norm import fused_add_rmsnorm, rmsnorm from flashinfer.norm import fused_add_rmsnorm, rmsnorm
@ -20,8 +21,8 @@ class HuggingFaceRMSNorm(nn.Module):
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: torch.Tensor | None = None, residual: Optional[torch.Tensor] = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
orig_dtype = x.dtype orig_dtype = x.dtype
x = x.to(torch.float32) x = x.to(torch.float32)
if residual is not None: if residual is not None:
@ -40,7 +41,7 @@ class HuggingFaceRMSNorm(nn.Module):
def rmsnorm_naive( def rmsnorm_naive(
x: torch.Tensor, x: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
residual: torch.Tensor | None = None, residual: Optional[torch.Tensor] = None,
eps: float = 1e-6, eps: float = 1e-6,
): ):
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps) naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
@ -64,7 +65,7 @@ def rmsnorm_naive(
def rmsnorm_flashinfer( def rmsnorm_flashinfer(
x: torch.Tensor, x: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
residual: torch.Tensor | None = None, residual: Optional[torch.Tensor] = None,
eps: float = 1e-6, eps: float = 1e-6,
): ):
orig_shape = x.shape orig_shape = x.shape
@ -88,7 +89,7 @@ def rmsnorm_flashinfer(
def rmsnorm_vllm( def rmsnorm_vllm(
x: torch.Tensor, x: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
residual: torch.Tensor | None = None, residual: Optional[torch.Tensor] = None,
eps: float = 1e-6, eps: float = 1e-6,
): ):
orig_shape = x.shape orig_shape = x.shape

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from itertools import accumulate from itertools import accumulate
from typing import Optional
import nvtx import nvtx
import torch import torch
@ -17,7 +18,7 @@ def benchmark_rope_kernels_multi_lora(
seq_len: int, seq_len: int,
num_heads: int, num_heads: int,
head_size: int, head_size: int,
rotary_dim: int | None, rotary_dim: Optional[int],
dtype: torch.dtype, dtype: torch.dtype,
seed: int, seed: int,
device: str, device: str,

View File

@ -1,19 +1,5 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Comprehensive 3-way SiLU Benchmark Suite
This benchmark compares three SiLU implementations:
1. SiLU V2 (CUDA) - Optimized CUDA kernel implementation
2. Triton Kernel - Triton-based implementation
The suite generates detailed performance comparisons including:
- Memory bandwidth utilization
- Speedup ratios (baseline vs optimized implementations)
- Performance across different expert configurations and token distributions
"""
from collections.abc import Callable from collections.abc import Callable
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -21,7 +7,7 @@ import numpy as np
import torch import torch
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
persistent_masked_m_silu_mul_quant, silu_mul_fp8_quant_deep_gemm_cuda,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
@ -108,7 +94,6 @@ def silu_mul_fp8_quant_deep_gemm_triton(
num_parallel_tokens, num_parallel_tokens,
group_size: int = 128, group_size: int = 128,
eps: float = 1e-10, eps: float = 1e-10,
expert_offsets: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
@ -189,7 +174,7 @@ def silu_mul_fp8_quant_deep_gemm_triton(
# Parse generation strategies # Parse generation strategies
strategies = ["random_imbalanced", "uniform", "max_t"] strategies = ["uniform", "max_t", "first_t"]
def benchmark( def benchmark(
@ -210,27 +195,15 @@ def benchmark(
current_platform.seed_everything(42 + seed_offset) current_platform.seed_everything(42 + seed_offset)
y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous() y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous()
if gen_strategy == "random_imbalanced": if gen_strategy == "uniform":
r = torch.rand(size=(E,), device="cuda")
def generate_expert_loads(n_e, total_tokens, ratio, device="cuda"):
mean = total_tokens // n_e
min_max = mean // ratio
e = torch.ones(size=(E,), dtype=torch.int64, device=device) * mean
e[0] = min_max
r = torch.rand(size=(E - 1,))
r /= r.sum()
r *= total_tokens - min_max
r = r.round().long()
e[1:] = r.to(device=device)
return e
tokens_per_expert = generate_expert_loads(E, total_tokens, 0.7, "cuda")
elif gen_strategy == "uniform":
r = torch.rand(size=(E,))
r /= r.sum() r /= r.sum()
r *= total_tokens r *= total_tokens
r = r.round().long() tokens_per_expert = r.int()
tokens_per_expert = r tokens_per_expert = torch.minimum(
tokens_per_expert,
torch.ones((E,), device=r.device, dtype=torch.int) * T,
)
elif gen_strategy == "max_t": elif gen_strategy == "max_t":
tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda") tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda")
tokens_per_expert.fill_(total_tokens / E) tokens_per_expert.fill_(total_tokens / E)
@ -308,34 +281,40 @@ def benchmark(
def create_comparison_plot( def create_comparison_plot(
ratios, silu_v2_times, triton_times, config_labels, strategy_name, id ratio, cuda_times, baseline_times, config_labels, strategy_name, id
): ):
fig, ax = plt.subplots(1, 1, figsize=(18, 6)) """Create a comparison plot for a specific generation strategy"""
fig, ax = plt.subplots(1, 1, figsize=(16, 6))
# Configure x-axis positions # Configure x-axis positions
x = np.arange(len(config_labels)) x = np.arange(len(config_labels))
width = 0.25 width = 0.35
# Execution Time plot (lower is better) # Execution Time plot (lower is better)
ax.bar(x, silu_v2_times, width, label="SiLU V2 (CUDA)", alpha=0.8, color="blue")
ax.bar( ax.bar(
x + width, triton_times, width, label="Triton Kernel", alpha=0.8, color="green" x - width / 2, cuda_times, width, label="CUDA Kernel", alpha=0.8, color="blue"
)
ax.bar(
x + width / 2,
baseline_times,
width,
label="Baseline",
alpha=0.8,
color="orange",
) )
# Add speedup labels over each bar trio # Add speedup labels over each bar pair
for i in range(len(x)): for i in range(len(x)):
triton_v2_speedup = ratios[i][1] # triton/v2 speedup = ratio[i]
max_height = max(silu_v2_times[i], triton_times[i]) max_height = max(cuda_times[i], baseline_times[i])
# Triton/V2 speedup
ax.text( ax.text(
x[i] + width / 2, x[i],
max_height + max_height * 0.02, max_height + max_height * 0.02,
f"{triton_v2_speedup:.2f}x", f"{speedup:.2f}x",
ha="center", ha="center",
va="bottom", va="bottom",
fontweight="bold", fontweight="bold",
fontsize=8, fontsize=9,
) )
ax.set_xlabel("Configuration") ax.set_xlabel("Configuration")
@ -353,75 +332,56 @@ def create_comparison_plot(
def create_combined_plot(all_results): def create_combined_plot(all_results):
"""Create a combined plot with all strategies in one PNG"""
num_strategies = len(all_results) num_strategies = len(all_results)
fig, axes = plt.subplots(num_strategies, 1, figsize=(22, 7 * num_strategies)) fig, axes = plt.subplots(num_strategies, 1, figsize=(20, 6 * num_strategies))
if num_strategies == 1: if num_strategies == 1:
axes = [axes] axes = [axes]
for idx, ( for idx, (
strategy_name, strategy_name,
all_ratios, ratio,
all_silu_v2_results, cuda_times,
all_triton_results, baseline_times,
config_labels, config_labels,
config_x_axis,
) in enumerate(all_results): ) in enumerate(all_results):
ax = axes[idx] ax = axes[idx]
# Flatten the nested results to get bandwidth percentages for plotting
silu_v2_bandwidths = []
triton_bandwidths = []
flat_ratios = []
for config_results in all_silu_v2_results:
for result in config_results:
silu_v2_bandwidths.append(result[3]) # bandwidth percentage
for config_results in all_triton_results:
for result in config_results:
triton_bandwidths.append(result[3]) # bandwidth percentage
for config_ratios in all_ratios:
for ratio in config_ratios:
flat_ratios.append(ratio)
# Configure x-axis positions # Configure x-axis positions
x = np.arange(len(config_labels)) x = np.arange(len(config_labels))
width = 0.25 width = 0.35
# Bandwidth utilization plot (higher is better) # Execution Time plot (lower is better)
ax.bar( ax.bar(
x, x - width / 2,
silu_v2_bandwidths, cuda_times,
width, width,
label="SiLU V2 (CUDA)", label="CUDA Kernel",
alpha=0.8, alpha=0.8,
color="blue", color="blue",
) )
ax.bar( ax.bar(
x + width, x + width / 2,
triton_bandwidths, baseline_times,
width, width,
label="Triton Kernel", label="Baseline",
alpha=0.8, alpha=0.8,
color="green", color="orange",
) )
# Add speedup labels over each bar trio # Add speedup labels over each bar pair
for i in range(len(x)): for i in range(len(x)):
triton_v2_speedup = flat_ratios[i] # triton/v2 speedup = ratio[i]
max_height = max(silu_v2_bandwidths[i], triton_bandwidths[i]) max_height = max(cuda_times[i], baseline_times[i])
# Triton/V2 speedup
ax.text( ax.text(
x[i] + width / 2, x[i],
max_height + max_height * 0.02, max_height + max_height * 0.02,
f"{triton_v2_speedup:.2f}x", f"{speedup:.2f}x",
ha="center", ha="center",
va="bottom", va="bottom",
fontweight="bold", fontweight="bold",
fontsize=8, fontsize=9,
) )
ax.set_xlabel("Configuration") ax.set_xlabel("Configuration")
@ -435,7 +395,7 @@ def create_combined_plot(all_results):
ax.grid(True, alpha=0.3) ax.grid(True, alpha=0.3)
plt.tight_layout() plt.tight_layout()
filename = "silu_benchmark_combined_3way.png" filename = "../../silu_bench/silu_benchmark_combined.png"
plt.savefig(filename, dpi=300, bbox_inches="tight") plt.savefig(filename, dpi=300, bbox_inches="tight")
plt.show() plt.show()
@ -445,9 +405,7 @@ def create_combined_plot(all_results):
outer_dim = 7168 outer_dim = 7168
configs = [ configs = [
# DeepSeekV3 Configs # DeepSeekV3 Configs
# (1, 56, 7168),
(8, 1024, 7168), (8, 1024, 7168),
# (32, 56, 7168),
# DeepSeekV3 Configs # DeepSeekV3 Configs
(32, 1024, 7168), (32, 1024, 7168),
# DeepSeekV3 Configs # DeepSeekV3 Configs
@ -459,7 +417,6 @@ num_warmups = 20
strategy_descriptions = { strategy_descriptions = {
"uniform": "Uniform Random", "uniform": "Uniform Random",
"random_imbalanced": "Imbalanced Random",
"max_t": "Even Assignment", "max_t": "Even Assignment",
"first_t": "experts[0] = T, experts[1:] = 0", "first_t": "experts[0] = T, experts[1:] = 0",
} }
@ -476,31 +433,28 @@ for id, strategy in enumerate(strategies):
print(f"Testing strategy: {strategy_descriptions[strategy]}") print(f"Testing strategy: {strategy_descriptions[strategy]}")
print(f"{'=' * 60}") print(f"{'=' * 60}")
# Collect benchmark data for all three algorithms # Collect benchmark data for both algorithms
config_labels = [] config_labels = []
config_x_axis = [] config_x_axis = []
all_silu_v2_results = [] all_cuda_results = []
all_triton_results = [] all_baseline_results = []
all_ratios = [] all_ratios = []
for E, T, H in configs: for E, T, H in configs:
total_tokens_config = [] total_tokens_config = [8 * E, 16 * E, 32 * E, 64 * E, 128 * E, 256 * E]
for i in [8, 16, 32, 64, 128, 256, 512]:
if i <= T:
total_tokens_config.append(i * E)
config_x_axis.append(total_tokens_config) config_x_axis.append(total_tokens_config)
silu_v2_results = [] cuda_results = []
triton_results = [] baseline_results = []
ratios = [] ratios = []
for total_tokens in total_tokens_config: for total_tokens in total_tokens_config:
config_label = f"E={E},T={T},H={H},TT={total_tokens}" config_label = f"E={E},T={T},H={H},TT={total_tokens}"
config_labels.append(config_label) config_labels.append(config_label)
# SiLU V2 (CUDA kernel) results # CUDA kernel results
time_ms_silu_v2, gflops, gbps, perc = benchmark( time_ms_cuda, gflops, gbps, perc = benchmark(
persistent_masked_m_silu_mul_quant, silu_mul_fp8_quant_deep_gemm_cuda,
E, E,
T, T,
H, H,
@ -509,9 +463,9 @@ for id, strategy in enumerate(strategies):
num_warmups=num_warmups, num_warmups=num_warmups,
gen_strategy=strategy, gen_strategy=strategy,
) )
silu_v2_results.append((time_ms_silu_v2, gflops, gbps, perc)) cuda_results.append((time_ms_cuda, gflops, gbps, perc))
# Triton kernel results # Baseline results
time_ms_triton, gflops, gbps, perc = benchmark( time_ms_triton, gflops, gbps, perc = benchmark(
silu_mul_fp8_quant_deep_gemm_triton, silu_mul_fp8_quant_deep_gemm_triton,
E, E,
@ -522,20 +476,12 @@ for id, strategy in enumerate(strategies):
num_warmups=num_warmups, num_warmups=num_warmups,
gen_strategy=strategy, gen_strategy=strategy,
) )
triton_results.append((time_ms_triton, gflops, gbps, perc)) baseline_results.append((time_ms_triton, gflops, gbps, perc))
ratios.append(time_ms_triton / time_ms_cuda)
# Calculate speedup ratios (triton baseline / implementation) print(f"Completed: {config_label}")
triton_v2_ratio = time_ms_triton / time_ms_silu_v2 all_cuda_results.append(cuda_results)
ratios.append(triton_v2_ratio) all_baseline_results.append(baseline_results)
print(
f"Completed: {config_label}:"
f" V2: {time_ms_silu_v2:.3f}ms,"
f" Triton: {time_ms_triton:.3f}ms"
)
all_silu_v2_results.append(silu_v2_results)
all_triton_results.append(triton_results)
all_ratios.append(ratios) all_ratios.append(ratios)
# Store results for combined plotting # Store results for combined plotting
@ -543,8 +489,8 @@ for id, strategy in enumerate(strategies):
( (
strategy_descriptions[strategy], strategy_descriptions[strategy],
all_ratios, all_ratios,
all_silu_v2_results, all_cuda_results,
all_triton_results, all_baseline_results,
config_labels, config_labels,
config_x_axis, config_x_axis,
) )
@ -552,18 +498,15 @@ for id, strategy in enumerate(strategies):
# Print summary table for this strategy # Print summary table for this strategy
print(f"\nSummary Table - {strategy_descriptions[strategy]}:") print(f"\nSummary Table - {strategy_descriptions[strategy]}:")
print(f" {'V2 Time(ms)':<12} {'Triton Time(ms)':<14} {'Triton/V2':<10}") print(f"{'Config':<20} {'CUDA Time(ms)':<12} {'Base Time(ms)':<12} {'Speedup':<8}")
print("-" * 90) print("-" * 60)
for i, (E, T, H) in enumerate(configs): for i, (E, T, H) in enumerate(configs):
# Get the first result for each config (simplifying for summary) speedup = baseline_results[i][0] / cuda_results[i][0]
v2_time = silu_v2_results[i][0]
triton_time = triton_results[i][0]
triton_v2_speedup = triton_time / v2_time
config_label = f"E={E:3d},T={T:4d},H={H:4d}" config_label = f"E={E:3d},T={T:4d},H={H:4d}"
print( print(
f"{config_label:<20} {v2_time:8.5f} {triton_time:10.5f} " f"{config_label:<20} {cuda_results[i][0]:8.5f} "
f"{triton_v2_speedup:8.2f}x" f"{baseline_results[i][0]:8.5f} {speedup:6.2f}x"
) )
@ -571,14 +514,15 @@ def create_total_tokens_plot(all_results):
num_strategies = len(all_results) num_strategies = len(all_results)
num_configs = len(configs) num_configs = len(configs)
# Create side-by-side subplots: 2 columns for speedup and bandwidth percentage
fig, axs = plt.subplots( fig, axs = plt.subplots(
num_strategies, num_configs * 2, figsize=(32, 8 * num_strategies) num_strategies, num_configs * 2, figsize=(28, 6 * num_strategies)
) )
# Add main title to the entire figure # Add main title to the entire figure
fig.suptitle( fig.suptitle(
"Performance Analysis: Speedup vs Bandwidth Utilization (SiLU V2, and Triton)", "Performance Analysis: Speedup vs Bandwidth Utilization (Triton & CUDA)",
fontsize=18, fontsize=16,
fontweight="bold", fontweight="bold",
y=0.98, y=0.98,
) )
@ -595,8 +539,8 @@ def create_total_tokens_plot(all_results):
( (
strategy_name, strategy_name,
all_ratios, all_ratios,
all_silu_v2_results, all_cuda_results,
all_triton_results, all_baseline_results,
config_labels, config_labels,
config_x_axis, config_x_axis,
) = result ) = result
@ -611,54 +555,42 @@ def create_total_tokens_plot(all_results):
ratios = all_ratios[config_idx] ratios = all_ratios[config_idx]
total_tokens_values = config_x_axis[config_idx] total_tokens_values = config_x_axis[config_idx]
# Extract speedup ratios # Extract CUDA and Triton bandwidth percentages
triton_v2_ratios = [ratio for ratio in ratios] cuda_bandwidth_percentages = [
result[3] for result in all_cuda_results[config_idx]
# Extract bandwidth percentages for all implementations
v2_bandwidth_percentages = [
result[3] for result in all_silu_v2_results[config_idx]
] ]
triton_bandwidth_percentages = [ triton_bandwidth_percentages = [
result[3] for result in all_triton_results[config_idx] result[3] for result in all_baseline_results[config_idx]
] ]
# Plot speedup ratios vs total tokens (left plot) # Plot speedup ratios vs total tokens (left plot)
ax_speedup.plot( ax_speedup.plot(
total_tokens_values, total_tokens_values, ratios, "bo-", linewidth=3, markersize=8
triton_v2_ratios,
"go-",
linewidth=3,
markersize=8,
label="Triton/V2 Speedup",
) )
ax_speedup.set_title( ax_speedup.set_title(
f"{strategy_name}\nSpeedup vs Baseline (Triton)\nE={E}, T={T}, H={H}", f"{strategy_name}\nSpeedup (CUDA/Triton)\nE={E}, T={T}, H={H}",
fontsize=12, fontsize=12,
fontweight="bold", fontweight="bold",
) )
ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11) ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11)
ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11) ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11)
ax_speedup.legend(prop={"weight": "bold"})
ax_speedup.grid(True, alpha=0.3) ax_speedup.grid(True, alpha=0.3)
# Plot bandwidth utilization (right plot)
ax_bandwidth.plot( ax_bandwidth.plot(
total_tokens_values, total_tokens_values,
v2_bandwidth_percentages, cuda_bandwidth_percentages,
"o-", "ro-",
linewidth=3, linewidth=3,
markersize=8, markersize=8,
label="SiLU V2", label="CUDA",
color="blue",
) )
ax_bandwidth.plot( ax_bandwidth.plot(
total_tokens_values, total_tokens_values,
triton_bandwidth_percentages, triton_bandwidth_percentages,
"o-", "go-",
linewidth=3, linewidth=3,
markersize=8, markersize=8,
label="Triton", label="Triton",
color="green",
) )
ax_bandwidth.set_title( ax_bandwidth.set_title(
f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}", f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}",
@ -686,12 +618,38 @@ def create_total_tokens_plot(all_results):
for label in ax.get_xticklabels() + ax.get_yticklabels(): for label in ax.get_xticklabels() + ax.get_yticklabels():
label.set_fontweight("bold") label.set_fontweight("bold")
# Add value labels on Triton/V2 speedup points # Add value labels on speedup points
for x, y in zip(total_tokens_values, triton_v2_ratios): for x, y in zip(total_tokens_values, ratios):
ax_speedup.annotate( ax_speedup.annotate(
f"{y:.2f}x", f"{y:.2f}x",
(x, y), (x, y),
textcoords="offset points", textcoords="offset points",
xytext=(0, 12),
ha="center",
fontsize=10,
fontweight="bold",
bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7),
)
# Add value labels on CUDA bandwidth points
for x, y in zip(total_tokens_values, cuda_bandwidth_percentages):
ax_bandwidth.annotate(
f"{y:.1f}%",
(x, y),
textcoords="offset points",
xytext=(0, 12),
ha="center",
fontsize=9,
fontweight="bold",
bbox=dict(boxstyle="round,pad=0.2", facecolor="red", alpha=0.3),
)
# Add value labels on Triton bandwidth points
for x, y in zip(total_tokens_values, triton_bandwidth_percentages):
ax_bandwidth.annotate(
f"{y:.1f}%",
(x, y),
textcoords="offset points",
xytext=(0, -15), xytext=(0, -15),
ha="center", ha="center",
fontsize=9, fontsize=9,
@ -701,20 +659,17 @@ def create_total_tokens_plot(all_results):
plt.tight_layout() plt.tight_layout()
plt.subplots_adjust(top=0.93) # Make room for main title plt.subplots_adjust(top=0.93) # Make room for main title
filename = "silu_benchmark_total_tokens_3way.png" filename = "silu_benchmark_total_tokens.png"
plt.savefig(filename, dpi=300, bbox_inches="tight") plt.savefig(filename, dpi=300, bbox_inches="tight")
plt.show() plt.show()
return filename return filename
# Create comprehensive 3-way comparison plots # Create combined plot with all strategies
combined_plot_filename = create_combined_plot(all_results) combined_plot_filename = create_total_tokens_plot(all_results)
total_tokens_plot_filename = create_total_tokens_plot(all_results)
print(f"\n{'=' * 80}") print(f"\n{'=' * 60}")
print("3-Way Benchmark Suite Complete!") print("Benchmark Complete!")
print(f"Generated combined comparison plot: {combined_plot_filename}") print(f"Generated combined plot: {combined_plot_filename}")
print(f"Generated total tokens analysis plot: {total_tokens_plot_filename}") print(f"{'=' * 60}")
print("Compared: SiLU V2 (CUDA), and Triton implementations")
print(f"{'=' * 80}")

View File

@ -4,6 +4,7 @@
import csv import csv
import os import os
from datetime import datetime from datetime import datetime
from typing import Optional
import flashinfer import flashinfer
import torch import torch
@ -27,7 +28,9 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@torch.no_grad() @torch.no_grad()
def benchmark_decode( def benchmark_decode(
dtype: torch.dtype, dtype: torch.dtype,
quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None], quant_dtypes: tuple[
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
],
batch_size: int, batch_size: int,
max_seq_len: int, max_seq_len: int,
num_heads: tuple[int, int] = (64, 8), num_heads: tuple[int, int] = (64, 8),

View File

@ -4,6 +4,7 @@
import csv import csv
import os import os
from datetime import datetime from datetime import datetime
from typing import Optional
import flashinfer import flashinfer
import torch import torch
@ -27,7 +28,9 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@torch.no_grad() @torch.no_grad()
def benchmark_prefill( def benchmark_prefill(
dtype: torch.dtype, dtype: torch.dtype,
quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None], quant_dtypes: tuple[
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
],
batch_size: int, batch_size: int,
max_seq_len: int, max_seq_len: int,
num_heads: tuple[int, int] = (64, 8), num_heads: tuple[int, int] = (64, 8),

View File

@ -14,7 +14,7 @@ import torch
from tqdm import tqdm from tqdm import tqdm
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_w8a8_triton_block_scaled_mm, _w8a8_block_fp8_matmul,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import triton from vllm.triton_utils import triton
@ -83,7 +83,7 @@ def w8a8_block_matmul(
) )
if A.dtype == torch.float8_e4m3fn: if A.dtype == torch.float8_e4m3fn:
kernel = _w8a8_triton_block_scaled_mm kernel = _w8a8_block_fp8_matmul
else: else:
raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.") raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")

View File

@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses import dataclasses
from collections.abc import Callable, Iterable from collections.abc import Iterable
from typing import Any from typing import Any, Callable, Optional
import torch import torch
import torch.utils.benchmark as TBenchmark import torch.utils.benchmark as TBenchmark
@ -55,7 +55,7 @@ class Bench:
def __init__( def __init__(
self, self,
cuda_graph_params: CudaGraphBenchParams | None, cuda_graph_params: Optional[CudaGraphBenchParams],
label: str, label: str,
sub_label: str, sub_label: str,
description: str, description: str,

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from statistics import mean from statistics import mean
from typing import Any, NamedTuple from typing import Any, NamedTuple, Optional, Union
import numpy as np # type: ignore import numpy as np # type: ignore
import pandas as pd # type: ignore import pandas as pd # type: ignore
@ -35,8 +35,8 @@ class Distribution(ABC):
class UniformDistribution(Distribution): class UniformDistribution(Distribution):
def __init__( def __init__(
self, self,
min_val: int | float, min_val: Union[int, float],
max_val: int | float, max_val: Union[int, float],
is_integer: bool = True, is_integer: bool = True,
) -> None: ) -> None:
self.min_val = min_val self.min_val = min_val
@ -56,7 +56,7 @@ class UniformDistribution(Distribution):
class ConstantDistribution(Distribution): class ConstantDistribution(Distribution):
def __init__(self, value: int | float) -> None: def __init__(self, value: Union[int, float]) -> None:
self.value = value self.value = value
self.max_val = value self.max_val = value
@ -68,7 +68,7 @@ class ConstantDistribution(Distribution):
class ZipfDistribution(Distribution): class ZipfDistribution(Distribution):
def __init__(self, alpha: float, max_val: int | None = None) -> None: def __init__(self, alpha: float, max_val: Optional[int] = None) -> None:
self.alpha = alpha self.alpha = alpha
self.max_val = max_val self.max_val = max_val
@ -83,7 +83,7 @@ class ZipfDistribution(Distribution):
class PoissonDistribution(Distribution): class PoissonDistribution(Distribution):
def __init__(self, alpha: float, max_val: int | None = None) -> None: def __init__(self, alpha: float, max_val: Optional[int] = None) -> None:
self.alpha = alpha self.alpha = alpha
self.max_val = max_val self.max_val = max_val
@ -100,11 +100,11 @@ class PoissonDistribution(Distribution):
class LognormalDistribution(Distribution): class LognormalDistribution(Distribution):
def __init__( def __init__(
self, self,
mean: float | None = None, mean: Optional[float] = None,
sigma: float | None = None, sigma: Optional[float] = None,
average: int | None = None, average: Optional[int] = None,
median_ratio: float | None = None, median_ratio: Optional[float] = None,
max_val: int | None = None, max_val: Optional[int] = None,
) -> None: ) -> None:
self.average = average self.average = average
self.median_ratio = median_ratio self.median_ratio = median_ratio

View File

@ -13,7 +13,7 @@ from datetime import datetime
from enum import Enum from enum import Enum
from http import HTTPStatus from http import HTTPStatus
from statistics import mean from statistics import mean
from typing import NamedTuple from typing import NamedTuple, Optional, Union
import aiohttp # type: ignore import aiohttp # type: ignore
import numpy as np # type: ignore import numpy as np # type: ignore
@ -46,9 +46,9 @@ class ConversationSampling(str, Enum):
class ClientArgs(NamedTuple): class ClientArgs(NamedTuple):
seed: int seed: int
max_num_requests: int | None max_num_requests: Optional[int]
skip_first_turn: bool skip_first_turn: bool
max_turns: int | None max_turns: Optional[int]
max_active_conversations: int max_active_conversations: int
verbose: bool verbose: bool
print_content: bool print_content: bool
@ -109,9 +109,9 @@ class RequestStats(NamedTuple):
class MetricStats: class MetricStats:
def __init__(self) -> None: def __init__(self) -> None:
self.min: float | None = None self.min: Optional[float] = None
self.max: float | None = None self.max: Optional[float] = None
self.avg: float | None = None self.avg: Optional[float] = None
self.sum = 0.0 self.sum = 0.0
self.count = 0 self.count = 0
@ -143,7 +143,7 @@ class MovingAverage:
self.index = 0 self.index = 0
self.sum = 0.0 self.sum = 0.0
self.count = 0 self.count = 0
self.avg: float | None = None self.avg: Optional[float] = None
def update(self, new_value: float) -> None: def update(self, new_value: float) -> None:
if self.count < self.window_size: if self.count < self.window_size:
@ -169,7 +169,7 @@ class MovingAverage:
class DebugStats: class DebugStats:
def __init__(self, logger: logging.Logger, window_size: int) -> None: def __init__(self, logger: logging.Logger, window_size: int) -> None:
self.logger = logger self.logger = logger
self.metrics: dict[str, MovingAverage | MetricStats] = { self.metrics: dict[str, Union[MovingAverage, MetricStats]] = {
"moving_avg_ttft_ms": MovingAverage(window_size), "moving_avg_ttft_ms": MovingAverage(window_size),
"moving_avg_tpot_ms": MovingAverage(window_size), "moving_avg_tpot_ms": MovingAverage(window_size),
"ttft_ms": MetricStats(), "ttft_ms": MetricStats(),
@ -198,6 +198,14 @@ class DebugStats:
self.logger.info("-" * 50) self.logger.info("-" * 50)
# Must support Python 3.8, we can't use str.removeprefix(prefix)
# introduced in Python 3.9
def remove_prefix(text: str, prefix: str) -> str:
if text.startswith(prefix):
return text[len(prefix) :]
return text
def nanosec_to_millisec(value: float) -> float: def nanosec_to_millisec(value: float) -> float:
return value / 1000000.0 return value / 1000000.0
@ -212,8 +220,8 @@ async def send_request(
chat_url: str, chat_url: str,
model: str, model: str,
stream: bool = True, stream: bool = True,
min_tokens: int | None = None, min_tokens: Optional[int] = None,
max_tokens: int | None = None, max_tokens: Optional[int] = None,
) -> ServerResponse: ) -> ServerResponse:
payload = { payload = {
"model": model, "model": model,
@ -242,9 +250,9 @@ async def send_request(
timeout = aiohttp.ClientTimeout(total=timeout_sec) timeout = aiohttp.ClientTimeout(total=timeout_sec)
valid_response = True valid_response = True
ttft: float | None = None ttft: Optional[float] = None
chunk_delay: list[int] = [] chunk_delay: list[int] = []
latency: float | None = None latency: Optional[float] = None
first_chunk = "" first_chunk = ""
generated_text = "" generated_text = ""
@ -261,7 +269,7 @@ async def send_request(
if not chunk_bytes: if not chunk_bytes:
continue continue
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
if chunk == "[DONE]": if chunk == "[DONE]":
# End of stream # End of stream
latency = time.perf_counter_ns() - start_time latency = time.perf_counter_ns() - start_time
@ -356,7 +364,7 @@ async def send_turn(
req_args: RequestArgs, req_args: RequestArgs,
verbose: bool, verbose: bool,
verify_output: bool, verify_output: bool,
) -> RequestStats | None: ) -> Optional[RequestStats]:
assert messages_to_use > 0 assert messages_to_use > 0
assert messages_to_use <= len(conversation_messages) assert messages_to_use <= len(conversation_messages)
@ -636,7 +644,7 @@ async def client_main(
if args.verbose: if args.verbose:
curr_time_sec: float = time.perf_counter() curr_time_sec: float = time.perf_counter()
time_since_last_turn: str | float = "N/A" time_since_last_turn: Union[str, float] = "N/A"
if conv_id in time_of_last_turn: if conv_id in time_of_last_turn:
time_since_last_turn = round( time_since_last_turn = round(
curr_time_sec - time_of_last_turn[conv_id], 3 curr_time_sec - time_of_last_turn[conv_id], 3
@ -761,7 +769,7 @@ def get_client_config(
"Number of conversations must be equal or larger than the number of clients" "Number of conversations must be equal or larger than the number of clients"
) )
max_req_per_client: int | None = None max_req_per_client: Optional[int] = None
if args.max_num_requests is not None: if args.max_num_requests is not None:
# Max number of requests per client # Max number of requests per client
req_per_client = args.max_num_requests // args.num_clients req_per_client = args.max_num_requests // args.num_clients
@ -928,13 +936,13 @@ async def main_mp(
f"{num_clients_finished} out of {bench_args.num_clients} clients finished, collected {len(client_metrics)} measurements, runtime {runtime_sec:.3f} sec{Color.RESET}" # noqa: E501 f"{num_clients_finished} out of {bench_args.num_clients} clients finished, collected {len(client_metrics)} measurements, runtime {runtime_sec:.3f} sec{Color.RESET}" # noqa: E501
) )
rps: str | float = round(len(client_metrics) / runtime_sec, 3) rps: Union[str, float] = round(len(client_metrics) / runtime_sec, 3)
if len(client_metrics) < (5 * bench_args.num_clients): if len(client_metrics) < (5 * bench_args.num_clients):
# Do not estimate the RPS if the number of samples is very low # Do not estimate the RPS if the number of samples is very low
# (threshold can be tuned if needed) # (threshold can be tuned if needed)
rps = "N/A" rps = "N/A"
runtime_left_sec: str | float = round( runtime_left_sec: Union[str, float] = round(
(runtime_sec / finished_convs) * (total_convs - finished_convs), 3 (runtime_sec / finished_convs) * (total_convs - finished_convs), 3
) )
if percent < 0.05: if percent < 0.05:
@ -1024,7 +1032,7 @@ def process_statistics(
warmup_percentages: list[float], warmup_percentages: list[float],
test_params: dict, test_params: dict,
verbose: bool, verbose: bool,
gen_conv_args: GenConvArgs | None = None, gen_conv_args: Optional[GenConvArgs] = None,
excel_output: bool = False, excel_output: bool = False,
) -> None: ) -> None:
if len(client_metrics) == 0: if len(client_metrics) == 0:
@ -1251,7 +1259,7 @@ async def main() -> None:
default=None, default=None,
help="The model name used in the API. " help="The model name used in the API. "
"If not specified, the model name will be the " "If not specified, the model name will be the "
"same as the `--model` argument. ", "same as the ``--model`` argument. ",
) )
parser.add_argument( parser.add_argument(

View File

@ -13,7 +13,7 @@ import argparse
import json import json
import random import random
from statistics import mean from statistics import mean
from typing import Any from typing import Any, Optional
import pandas as pd # type: ignore import pandas as pd # type: ignore
import tqdm # type: ignore import tqdm # type: ignore
@ -25,7 +25,7 @@ def has_non_english_chars(text: str) -> bool:
def content_is_valid( def content_is_valid(
content: str, min_content_len: int | None, max_content_len: int | None content: str, min_content_len: Optional[int], max_content_len: Optional[int]
) -> bool: ) -> bool:
if min_content_len and len(content) < min_content_len: if min_content_len and len(content) < min_content_len:
return False return False
@ -37,7 +37,7 @@ def content_is_valid(
def print_stats( def print_stats(
conversations: "list[dict[Any, Any]]", tokenizer: AutoTokenizer | None = None conversations: "list[dict[Any, Any]]", tokenizer: Optional[AutoTokenizer] = None
) -> None: ) -> None:
# Collect statistics # Collect statistics
stats = [] stats = []
@ -109,12 +109,12 @@ def convert_sharegpt_to_openai(
seed: int, seed: int,
input_file: str, input_file: str,
output_file: str, output_file: str,
max_items: int | None, max_items: Optional[int],
min_content_len: int | None = None, min_content_len: Optional[int] = None,
max_content_len: int | None = None, max_content_len: Optional[int] = None,
min_turns: int | None = None, min_turns: Optional[int] = None,
max_turns: int | None = None, max_turns: Optional[int] = None,
model: str | None = None, model: Optional[str] = None,
) -> None: ) -> None:
if min_turns and max_turns: if min_turns and max_turns:
assert min_turns <= max_turns assert min_turns <= max_turns

View File

@ -188,66 +188,34 @@ else()
message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.") message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.")
endif() endif()
#
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
# Flag to enable ACL kernels for AARCH64 platforms
if (VLLM_BUILD_ACL STREQUAL "ON")
set(USE_ACL ON)
else()
set(USE_ACL OFF)
endif()
# Build oneDNN for GEMM kernels (only for x86-AVX512 /ARM platforms)
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
# Fetch and build Arm Compute Library (ACL) as oneDNN's backend for AArch64 FetchContent_Declare(
# TODO [fadara01]: remove this once ACL can be fetched and built automatically as a dependency of oneDNN oneDNN
if(ASIMD_FOUND) GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
if(DEFINED ENV{ACL_ROOT_DIR} AND IS_DIRECTORY "$ENV{ACL_ROOT_DIR}") GIT_TAG v3.9
message(STATUS "Using ACL from specified source directory: $ENV{ACL_ROOT_DIR}") GIT_PROGRESS TRUE
else() GIT_SHALLOW TRUE
message(STATUS "Downloading Arm Compute Library (ACL) from GitHub") )
FetchContent_Populate(arm_compute
SUBBUILD_DIR "${FETCHCONTENT_BASE_DIR}/arm_compute-subbuild"
SOURCE_DIR "${FETCHCONTENT_BASE_DIR}/arm_compute-src"
GIT_REPOSITORY https://github.com/ARM-software/ComputeLibrary.git
GIT_TAG v52.2.0
GIT_SHALLOW TRUE
GIT_PROGRESS TRUE
)
set(ENV{ACL_ROOT_DIR} "${arm_compute_SOURCE_DIR}")
endif()
# Build ACL with scons if(USE_ACL)
include(ProcessorCount) find_library(ARM_COMPUTE_LIBRARY NAMES arm_compute PATHS $ENV{ACL_ROOT_DIR}/build/)
ProcessorCount(_NPROC) if(NOT ARM_COMPUTE_LIBRARY)
execute_process( message(FATAL_ERROR "Could not find ARM Compute Library: please set ACL_ROOT_DIR")
COMMAND scons -j${_NPROC}
Werror=0 debug=0 neon=1 examples=0 embed_kernels=0 os=linux
arch=armv8.2-a build=native benchmark_examples=0 fixed_format_kernels=1
multi_isa=1 openmp=1 cppthreads=0
WORKING_DIRECTORY "$ENV{ACL_ROOT_DIR}"
RESULT_VARIABLE _acl_rc
)
if(NOT _acl_rc EQUAL 0)
message(FATAL_ERROR "ACL SCons build failed (exit ${_acl_rc}).")
endif() endif()
set(ONEDNN_AARCH64_USE_ACL "ON") set(ONEDNN_AARCH64_USE_ACL "ON")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
add_compile_definitions(VLLM_USE_ACL) add_compile_definitions(VLLM_USE_ACL)
endif() endif()
set(FETCHCONTENT_SOURCE_DIR_ONEDNN "$ENV{FETCHCONTENT_SOURCE_DIR_ONEDNN}" CACHE PATH "Path to a local oneDNN source directory.")
if(FETCHCONTENT_SOURCE_DIR_ONEDNN)
message(STATUS "Using oneDNN from specified source directory: ${FETCHCONTENT_SOURCE_DIR_ONEDNN}")
FetchContent_Declare(
oneDNN
SOURCE_DIR ${FETCHCONTENT_SOURCE_DIR_ONEDNN}
)
else()
message(STATUS "Downloading oneDNN from GitHub")
FetchContent_Declare(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
GIT_TAG v3.9
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
endif()
set(ONEDNN_LIBRARY_TYPE "STATIC") set(ONEDNN_LIBRARY_TYPE "STATIC")
set(ONEDNN_BUILD_DOC "OFF") set(ONEDNN_BUILD_DOC "OFF")
set(ONEDNN_BUILD_EXAMPLES "OFF") set(ONEDNN_BUILD_EXAMPLES "OFF")
@ -259,7 +227,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
set(ONEDNN_ENABLE_ITT_TASKS "OFF") set(ONEDNN_ENABLE_ITT_TASKS "OFF")
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
set(ONEDNN_VERBOSE "OFF") set(ONEDNN_VERBOSE "ON")
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
FetchContent_MakeAvailable(oneDNN) FetchContent_MakeAvailable(oneDNN)
@ -341,4 +309,4 @@ define_gpu_extension_target(
WITH_SOABI WITH_SOABI
) )
message(STATUS "Enabling C extension.") message(STATUS "Enabling C extension.")

View File

@ -1,97 +0,0 @@
include(FetchContent)
set(CUTLASS_INCLUDE_DIR "${CUTLASS_INCLUDE_DIR}" CACHE PATH "Path to CUTLASS include/ directory")
if(DEFINED ENV{QUTLASS_SRC_DIR})
set(QUTLASS_SRC_DIR $ENV{QUTLASS_SRC_DIR})
endif()
if(QUTLASS_SRC_DIR)
FetchContent_Declare(
qutlass
SOURCE_DIR ${QUTLASS_SRC_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
)
else()
FetchContent_Declare(
qutlass
GIT_REPOSITORY https://github.com/IST-DASLab/qutlass.git
GIT_TAG 830d2c4537c7396e14a02a46fbddd18b5d107c65
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
)
endif()
FetchContent_Populate(qutlass)
if(NOT qutlass_SOURCE_DIR)
message(FATAL_ERROR "[QUTLASS] source directory could not be resolved.")
endif()
message(STATUS "[QUTLASS] QuTLASS is available at ${qutlass_SOURCE_DIR}")
cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND QUTLASS_ARCHS)
if(QUTLASS_ARCHS MATCHES "10\\.0a")
set(QUTLASS_TARGET_CC 100)
elseif(QUTLASS_ARCHS MATCHES "12\\.0a")
set(QUTLASS_TARGET_CC 120)
else()
message(FATAL_ERROR "[QUTLASS] internal error parsing CUDA_ARCHS='${QUTLASS_ARCHS}'.")
endif()
set(QUTLASS_SOURCES
${qutlass_SOURCE_DIR}/qutlass/csrc/bindings.cpp
${qutlass_SOURCE_DIR}/qutlass/csrc/gemm.cu
${qutlass_SOURCE_DIR}/qutlass/csrc/gemm_ada.cu
${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx.cu
${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv.cu
${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx_sm100.cu
${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv_sm100.cu
)
set(QUTLASS_INCLUDES
${qutlass_SOURCE_DIR}
${qutlass_SOURCE_DIR}/qutlass
${qutlass_SOURCE_DIR}/qutlass/csrc/include
${qutlass_SOURCE_DIR}/qutlass/csrc/include/cutlass_extensions
)
if(CUTLASS_INCLUDE_DIR AND EXISTS "${CUTLASS_INCLUDE_DIR}/cutlass/cutlass.h")
list(APPEND QUTLASS_INCLUDES "${CUTLASS_INCLUDE_DIR}")
elseif(EXISTS "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include/cutlass/cutlass.h")
list(APPEND QUTLASS_INCLUDES "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include")
message(STATUS "[QUTLASS] Using QuTLASS vendored CUTLASS headers (no vLLM CUTLASS detected).")
else()
message(FATAL_ERROR "[QUTLASS] CUTLASS headers not found. "
"Set -DCUTLASS_INCLUDE_DIR=/path/to/cutlass/include")
endif()
set_gencode_flags_for_srcs(
SRCS "${QUTLASS_SOURCES}"
CUDA_ARCHS "${QUTLASS_ARCHS}"
)
target_sources(_C PRIVATE ${QUTLASS_SOURCES})
target_include_directories(_C PRIVATE ${QUTLASS_INCLUDES})
target_compile_definitions(_C PRIVATE
QUTLASS_DISABLE_PYBIND=1
TARGET_CUDA_ARCH=${QUTLASS_TARGET_CC}
)
set_property(SOURCE ${QUTLASS_SOURCES} APPEND PROPERTY COMPILE_OPTIONS
$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr --use_fast_math -O3>
)
else()
if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_LESS "12.8")
message(STATUS
"[QUTLASS] Skipping build: CUDA 12.8 or newer is required (found ${CMAKE_CUDA_COMPILER_VERSION}).")
else()
message(STATUS
"[QUTLASS] Skipping build: no supported arch (12.0a / 10.0a) found in "
"CUDA_ARCHS='${CUDA_ARCHS}'.")
endif()
endif()

View File

@ -38,7 +38,7 @@ else()
FetchContent_Declare( FetchContent_Declare(
vllm-flash-attn vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG a893712401d70362fbb299cd9c4b3476e8e9ed54 GIT_TAG 4695e6bed5366c41e28c06cd86170166e4f43d00
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types # Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

View File

@ -1,12 +0,0 @@
codecov:
require_ci_to_pass: false
fixes:
# Map source code paths to repository root paths
# Wildcards match any Python version (python3.*)
- "/vllm-workspace/src/vllm/::vllm/"
- "/vllm-workspace/vllm/::vllm/"
- "/usr/local/lib/python3.*/dist-packages/vllm/::vllm/"
- "/usr/local/lib/python3.*/site-packages/vllm/::vllm/"
- "/usr/lib/python3.*/dist-packages/vllm/::vllm/"
- "/usr/lib/python3.*/site-packages/vllm/::vllm/"

View File

@ -28,10 +28,10 @@
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include "../quantization/w8a8/fp8/amd/quant_utils.cuh" #include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#else #else
#include "../quantization/w8a8/fp8/nvidia/quant_utils.cuh" #include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif #endif
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))

View File

@ -125,37 +125,32 @@ public:
} }
static void set_split_kv (KernelArguments& args) { static void set_split_kv (KernelArguments& args) {
// printf("set_split_kv start");
if (args.split_kv >= 1) return; if (args.split_kv >= 1) return;
auto [H, K, D, B] = args.problem_shape; auto [H, K, D, B] = args.problem_shape;
// std::cout << H << " " << K << " " << D << " " << B << "\n";
int sm_count = args.hw_info.sm_count; int sm_count = args.hw_info.sm_count;
float seq_length_k = static_cast<float>(K) / 1024.0f; // printf(" sm_count = %d\n", sm_count);
int max_splits = 1; int max_splits = ceil_div(K, 128);
max_splits = min(16, max_splits);
if (B <= 4 && seq_length_k >= 16) { // TODO: This avoids a hang when the batch size larger than 1 and
max_splits = 16; // there is more than 1 kv_splits.
// Discuss with NVIDIA how this can be fixed.
if (B > 1) {
max_splits = min(1, max_splits);
} }
else if (B <= 8 && seq_length_k >= 4) {
max_splits = 8; // printf(" max_splits = %d\n", max_splits);
}
else if ((B <= 16 && seq_length_k >= 8) ||
(B == 48 && seq_length_k >= 32)) {
max_splits = 4;
}
else if ((B <= 32 && seq_length_k >= 16) ||
(B == 96 && seq_length_k >= 16)) {
max_splits = 2;
}
else {
max_splits = 1;
}
// Wave-aware scheduling: ensure integer number of waves in K dimension
int sms_per_batch = max(1, sm_count / B); int sms_per_batch = max(1, sm_count / B);
// printf(" sms_per_batch = %d\n", sms_per_batch);
int split_heur = min(max_splits, sms_per_batch); int split_heur = min(max_splits, sms_per_batch);
int waves = ceil_div(B * split_heur, sm_count); int waves = ceil_div(B * split_heur, sm_count);
int k_waves = ceil_div(max_splits, split_heur); int k_waves = ceil_div(max_splits, split_heur);
int split_wave_aware = ceil_div(max_splits, k_waves); int split_wave_aware = ceil_div(max_splits, k_waves);
args.split_kv = split_wave_aware; args.split_kv = split_wave_aware;
// printf(" args.split_kv = %d\n", args.split_kv);
} }
/// Determines whether the GEMM can execute the given problem. /// Determines whether the GEMM can execute the given problem.

View File

@ -64,11 +64,3 @@ void indexer_k_quant_and_cache(
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size int64_t quant_block_size, // quantization block size
const std::string& scale_fmt); const std::string& scale_fmt);
// Extract function to gather quantized K cache
void cp_gather_indexer_k_quant_cache(
const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& dst_k, // [num_tokens, head_dim]
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
const torch::Tensor& block_table, // [batch_size, num_blocks]
const torch::Tensor& cu_seq_lens); // [batch_size + 1]

View File

@ -9,9 +9,9 @@
#include "quantization/vectorization_utils.cuh" #include "quantization/vectorization_utils.cuh"
#ifdef USE_ROCM #ifdef USE_ROCM
#include "quantization/w8a8/fp8/amd/quant_utils.cuh" #include "quantization/fp8/amd/quant_utils.cuh"
#else #else
#include "quantization/w8a8/fp8/nvidia/quant_utils.cuh" #include "quantization/fp8/nvidia/quant_utils.cuh"
#endif #endif
#include <algorithm> #include <algorithm>
@ -572,70 +572,6 @@ __global__ void indexer_k_quant_and_cache_kernel(
} }
} }
template <int BLOCK_Y_SIZE>
__global__ void cp_gather_indexer_k_quant_cache_kernel(
const char* __restrict__ kv_cache, // [num_blocks, block_size,
// cache_stride]
char* __restrict__ dst_k, // [num_tokens, head_dim]
char* __restrict__ dst_scale, // [num_tokens, head_dim / quant_block_size *
// 4]
const int* __restrict__ block_table, // [batch_size, num_blocks]
const int* __restrict__ cu_seq_lens, // [batch_size + 1]
const int batch_size, // batch size
const int64_t token_stride, // stride for each token in dst_k
const int64_t head_dim, // dimension of each head
const int64_t block_stride, // stride for each block in kv_cache
const int64_t cache_token_stride, // stride for each token in kv_cache
const int64_t cache_block_size, // num_tokens for each block in kv_cache
const int num_blocks, // number of blocks
const int num_tokens, // number of tokens
const int quant_block_size // quantization block size
) {
constexpr int VEC_SIZE = sizeof(float4) / sizeof(char);
const int token_idx = blockIdx.x * blockDim.y + threadIdx.y;
const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE;
// Find batch index within a block
__shared__ int batch_idx[BLOCK_Y_SIZE];
for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x));
iter++) {
int tid = iter * blockDim.x + threadIdx.x;
if (tid < batch_size) {
const int seq_start = cu_seq_lens[tid];
const int seq_end = cu_seq_lens[tid + 1];
if (token_idx >= seq_start && token_idx < seq_end) {
batch_idx[threadIdx.y] = tid;
}
}
}
#ifndef USE_ROCM
__syncwarp();
#endif
if (head_idx >= head_dim || token_idx >= num_tokens) {
return;
}
const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]];
const int block_idx = block_table[batch_idx[threadIdx.y] * num_blocks +
inbatch_seq_idx / cache_block_size];
const int64_t src_block_offset = block_idx * block_stride;
const int64_t cache_inblock_offset =
(inbatch_seq_idx % cache_block_size) * head_dim + head_idx;
const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset;
const int64_t dst_inblock_offset = token_idx * token_stride + head_idx;
reinterpret_cast<float4*>(dst_k)[dst_inblock_offset / VEC_SIZE] =
reinterpret_cast<const float4*>(kv_cache)[src_inblock_offset / VEC_SIZE];
;
if (threadIdx.x == 0) {
const int64_t src_scale_offset =
src_block_offset + cache_block_size * head_dim +
cache_inblock_offset * 4 / quant_block_size;
reinterpret_cast<float*>(dst_scale)[dst_inblock_offset / quant_block_size] =
reinterpret_cast<const float*>(kv_cache)[src_scale_offset / 4];
}
}
} // namespace vllm } // namespace vllm
// KV_T is the data type of key and value tensors. // KV_T is the data type of key and value tensors.
@ -1237,59 +1173,3 @@ void indexer_k_quant_and_cache(
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3", DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
CALL_INDEXER_K_QUANT_AND_CACHE); CALL_INDEXER_K_QUANT_AND_CACHE);
} }
// Macro to dispatch the kernel based on the data amount.
#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE) \
vllm::cp_gather_indexer_k_quant_cache_kernel<BLOCK_Y_SIZE> \
<<<dim3((num_tokens + BLOCK_Y_SIZE - 1) / BLOCK_Y_SIZE, \
(head_dim + 8 * vec_size - 1) / (8 * vec_size)), \
dim3(8, BLOCK_Y_SIZE), 0, stream>>>( \
reinterpret_cast<char*>(kv_cache.data_ptr()), \
reinterpret_cast<char*>(dst_k.data_ptr()), \
reinterpret_cast<char*>(dst_scale.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
batch_size, dst_k.stride(0), dst_k.size(1), kv_cache.stride(0), \
kv_cache.stride(1), kv_cache.size(1), block_table.size(1), \
num_tokens, quant_block_size);
void cp_gather_indexer_k_quant_cache(
const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& dst_k, // [num_tokens, head_dim]
torch::Tensor& dst_scale, // [num_tokens, head_dim / quant_block_size * 4]
const torch::Tensor& block_table, // [batch_size, num_blocks]
const torch::Tensor& cu_seq_lens // [batch_size + 1]
) {
int batch_size = block_table.size(0);
int num_tokens = dst_k.size(0);
int head_dim = dst_k.size(1);
int quant_block_size = head_dim * 4 / dst_scale.size(1);
TORCH_CHECK(kv_cache.device() == dst_k.device(),
"kv_cache and dst_k must be on the same device");
TORCH_CHECK(kv_cache.device() == dst_scale.device(),
"kv_cache and dst_scale must be on the same device");
TORCH_CHECK(kv_cache.device() == block_table.device(),
"kv_cache and block_table must be on the same device");
TORCH_CHECK(kv_cache.device() == cu_seq_lens.device(),
"kv_cache and cu_seq_lens must be on the same device");
TORCH_CHECK(head_dim % quant_block_size == 0,
"head_dim must be divisible by quant_block_size");
constexpr int vec_size = 16;
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_cache));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (num_tokens < 32) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(1);
} else if (num_tokens < 64) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(2);
} else if (num_tokens < 128) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(4);
} else if (num_tokens < 256) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(8);
} else if (num_tokens < 512) {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(16);
} else {
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32);
}
}

View File

@ -5,15 +5,12 @@
namespace vllm { namespace vllm {
// vllm_is_batch_invariant(); returns true // vllm_kernel_override_batch_invariant(); returns true
// if env VLLM_BATCH_INVARIANT=1 // if env VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1
inline bool vllm_is_batch_invariant() { inline bool vllm_kernel_override_batch_invariant() {
static bool cached = []() { std::string env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT";
std::string env_key = "VLLM_BATCH_INVARIANT"; const char* val = std::getenv(env_key.c_str());
const char* val = std::getenv(env_key.c_str()); return (val && std::atoi(val) != 0) ? 1 : 0;
return (val && std::atoi(val) != 0) ? 1 : 0;
}();
return cached;
} }
} // namespace vllm } // namespace vllm

View File

@ -12,7 +12,6 @@ using CubMaxOp = cub::Max;
#endif // CUB_VERSION #endif // CUB_VERSION
#else #else
#include <hipcub/hipcub.hpp> #include <hipcub/hipcub.hpp>
namespace cub = hipcub; using CubAddOp = cub::Sum;
using CubAddOp = hipcub::Sum; using CubMaxOp = cub::Max;
using CubMaxOp = hipcub::Max;
#endif // USE_ROCM #endif // USE_ROCM

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum import enum
from typing import Union
from cutlass_library import * from cutlass_library import *
@ -21,7 +22,7 @@ class MixedInputKernelScheduleType(enum.Enum):
TmaWarpSpecializedCooperative = enum_auto() TmaWarpSpecializedCooperative = enum_auto()
VLLMDataTypeNames: dict[VLLMDataType | DataType, str] = { VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = {
**DataTypeNames, # type: ignore **DataTypeNames, # type: ignore
**{ **{
VLLMDataType.u4b8: "u4b8", VLLMDataType.u4b8: "u4b8",
@ -29,7 +30,7 @@ VLLMDataTypeNames: dict[VLLMDataType | DataType, str] = {
}, },
} }
VLLMDataTypeTag: dict[VLLMDataType | DataType, str] = { VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
**DataTypeTag, # type: ignore **DataTypeTag, # type: ignore
**{ **{
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t", VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
@ -37,7 +38,7 @@ VLLMDataTypeTag: dict[VLLMDataType | DataType, str] = {
}, },
} }
VLLMDataTypeSize: dict[VLLMDataType | DataType, int] = { VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
**DataTypeSize, # type: ignore **DataTypeSize, # type: ignore
**{ **{
VLLMDataType.u4b8: 4, VLLMDataType.u4b8: 4,
@ -45,7 +46,7 @@ VLLMDataTypeSize: dict[VLLMDataType | DataType, int] = {
}, },
} }
VLLMDataTypeVLLMScalarTypeTag: dict[VLLMDataType | DataType, str] = { VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
VLLMDataType.u4b8: "vllm::kU4B8", VLLMDataType.u4b8: "vllm::kU4B8",
VLLMDataType.u8b128: "vllm::kU8B128", VLLMDataType.u8b128: "vllm::kU8B128",
DataType.u4: "vllm::kU4", DataType.u4: "vllm::kU4",
@ -56,7 +57,7 @@ VLLMDataTypeVLLMScalarTypeTag: dict[VLLMDataType | DataType, str] = {
DataType.bf16: "vllm::kBfloat16", DataType.bf16: "vllm::kBfloat16",
} }
VLLMDataTypeTorchDataTypeTag: dict[VLLMDataType | DataType, str] = { VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
DataType.u8: "at::ScalarType::Byte", DataType.u8: "at::ScalarType::Byte",
DataType.s8: "at::ScalarType::Char", DataType.s8: "at::ScalarType::Char",
DataType.e4m3: "at::ScalarType::Float8_e4m3fn", DataType.e4m3: "at::ScalarType::Float8_e4m3fn",
@ -66,7 +67,9 @@ VLLMDataTypeTorchDataTypeTag: dict[VLLMDataType | DataType, str] = {
DataType.f32: "at::ScalarType::Float", DataType.f32: "at::ScalarType::Float",
} }
VLLMKernelScheduleTag: dict[MixedInputKernelScheduleType | KernelScheduleType, str] = { VLLMKernelScheduleTag: dict[
Union[MixedInputKernelScheduleType, KernelScheduleType], str
] = {
**KernelScheduleTag, # type: ignore **KernelScheduleTag, # type: ignore
**{ **{
MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized", # noqa: E501 MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized", # noqa: E501

View File

@ -2,7 +2,6 @@
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "cub_helpers.h" #include "cub_helpers.h"
#include "core/batch_invariant.hpp" #include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh"
#include <torch/cuda.h> #include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
@ -19,22 +18,11 @@ __global__ void rms_norm_kernel(
const float epsilon, const int num_tokens, const int hidden_size) { const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance; __shared__ float s_variance;
float variance = 0.0f; float variance = 0.0f;
const scalar_t* input_row = input + blockIdx.x * input_stride;
constexpr int VEC_SIZE = 8; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) { const float x = (float)input[blockIdx.x * input_stride + idx];
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
float x = static_cast<float>(vec.val[i]);
variance += x * x;
}
};
auto scalar_op = [&variance](const scalar_t& val) {
float x = static_cast<float>(val);
variance += x * x; variance += x * x;
}; }
vllm::vectorize_read_with_alignment<VEC_SIZE>(
input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
using BlockReduce = cub::BlockReduce<float, 1024>; using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore; __shared__ typename BlockReduce::TempStorage reduceStore;
@ -148,6 +136,211 @@ fused_add_rms_norm_kernel(
} }
} }
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck.
_f16VecPN struct extends _f16Vec to add operations specifically required for
polynomial normalization (poly norm).
The original _f16Vec does not include the sum-of-powers computation or
in-place polynomial normalization logic. */
template <typename scalar_t, int width>
struct alignas(16) _f16VecPN : _f16Vec<scalar_t, width> {
using Base = _f16Vec<scalar_t, width>;
using Converter = typename Base::Converter;
using T1 = typename Base::T1;
using T2 = typename Base::T2;
using Base::data;
__device__ auto sum_pows() const {
float s2 = 0.0f, s4 = 0.0f, s6 = 0.0f;
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
float x2 = z.x * z.x;
float x4 = x2 * x2;
float x6 = x4 * x2;
float y2 = z.y * z.y;
float y4 = y2 * y2;
float y6 = y4 * y2;
s2 += x2 + y2;
s4 += x4 + y4;
s6 += x6 + y6;
}
return std::make_tuple(s2, s4, s6);
}
__device__ void poly_norm_inplace(const float w2_inv_std,
const float w1_inv_std2,
const float w0_inv_std3, const float bias) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
float x2 = z.x * z.x;
float x3 = x2 * z.x;
z.x = w2_inv_std * z.x + w1_inv_std2 * x2 + w0_inv_std3 * x3 + bias;
float y2 = z.y * z.y;
float y3 = y2 * z.y;
z.y = w2_inv_std * z.y + w1_inv_std2 * y2 + w0_inv_std3 * y3 + bias;
auto out = Converter::convert(z);
data[i] = out.x;
data[i + 1] = out.y;
}
}
};
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [3]
const scalar_t* __restrict__ bias, // [1]
const float epsilon, const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16VecPN<scalar_t, width>>);
static_assert(sizeof(_f16VecPN<scalar_t, width>) == sizeof(scalar_t) * width);
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ input_v =
reinterpret_cast<const _f16VecPN<scalar_t, width>*>(input);
const int vec_hidden_size = hidden_size / width;
float variance = 0.0f;
float variance2 = 0.0f;
float variance3 = 0.0f;
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16VecPN<scalar_t, width> temp = input_v[id];
auto [x2, x4, x6] = temp.sum_pows();
variance += x2;
variance2 += x4;
variance3 += x6;
}
float3 thread_variances = make_float3(variance, variance2, variance3);
struct SumOp {
__device__ float3 operator()(const float3& a, const float3& b) const {
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
};
using BlockReduce = cub::BlockReduce<float3, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
float3 block_variances =
BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
variance = block_variances.x;
variance2 = block_variances.y;
variance3 = block_variances.z;
__shared__ float s_w2_inv_std;
__shared__ float s_w1_inv_std2;
__shared__ float s_w0_inv_std3;
__shared__ float s_bias;
if (threadIdx.x == 0) {
float w0 = (float)weight[0];
float w1 = (float)weight[1];
float w2 = (float)weight[2];
s_bias = (float)bias[0];
s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
}
__syncthreads();
auto* __restrict__ out_v = reinterpret_cast<_f16VecPN<scalar_t, width>*>(out);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16VecPN<scalar_t, width> temp = input_v[id];
temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias);
out_v[id] = temp;
}
}
/* Generic poly_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [3]
const scalar_t* __restrict__ bias, // [1]
const float epsilon, const int hidden_size) {
float variance = 0.0f;
float variance2 = 0.0f;
float variance3 = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float x2 = x * x;
float x4 = x2 * x2;
float x6 = x4 * x2;
variance += x2;
variance2 += x4;
variance3 += x6;
}
float3 thread_variances = make_float3(variance, variance2, variance3);
struct SumOp {
__device__ float3 operator()(const float3& a, const float3& b) const {
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
};
using BlockReduce = cub::BlockReduce<float3, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
float3 block_variances =
BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
variance = block_variances.x;
variance2 = block_variances.y;
variance3 = block_variances.z;
__shared__ float s_w2_inv_std;
__shared__ float s_w1_inv_std2;
__shared__ float s_w0_inv_std3;
__shared__ float s_bias;
if (threadIdx.x == 0) {
float w0 = (float)weight[0];
float w1 = (float)weight[1];
float w2 = (float)weight[2];
s_bias = (float)bias[0];
s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float x2 = x * x;
float x3 = x2 * x;
out[blockIdx.x * hidden_size + idx] =
(scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 +
s_bias);
}
}
} // namespace vllm } // namespace vllm
void rms_norm(torch::Tensor& out, // [..., hidden_size] void rms_norm(torch::Tensor& out, // [..., hidden_size]
@ -159,26 +352,18 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
TORCH_CHECK(weight.is_contiguous()); TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
// We cannot just use `input.stride(-2)` if the tensor is not row-major. int64_t input_stride = input.stride(-2);
// Instead, we use a 2d view to get the second-innermost stride.
// That way the dimensions (except the last one) can be arbitrarily permuted.
torch::Tensor input_view = input.view({-1, hidden_size});
int num_tokens = input_view.numel() / hidden_size;
int64_t input_stride = input_view.stride(-2);
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024)); dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
input_view.scalar_type(), "rms_norm_kernel", [&] { vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>( out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens, });
hidden_size);
});
} }
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
@ -195,8 +380,6 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size] torch::Tensor& weight, // [hidden_size]
double epsilon) { double epsilon) {
TORCH_CHECK(weight.scalar_type() == input.scalar_type());
TORCH_CHECK(input.scalar_type() == residual.scalar_type());
TORCH_CHECK(residual.is_contiguous()); TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(weight.is_contiguous()); TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
@ -231,7 +414,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
wt_ptr % req_alignment_bytes == 0; wt_ptr % req_alignment_bytes == 0;
bool offsets_are_multiple_of_vector_width = bool offsets_are_multiple_of_vector_width =
hidden_size % vector_width == 0 && input_stride % vector_width == 0; hidden_size % vector_width == 0 && input_stride % vector_width == 0;
bool batch_invariant_launch = vllm::vllm_is_batch_invariant(); bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width && if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
!batch_invariant_launch) { !batch_invariant_launch) {
LAUNCH_FUSED_ADD_RMS_NORM(8); LAUNCH_FUSED_ADD_RMS_NORM(8);
@ -239,3 +422,50 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
LAUNCH_FUSED_ADD_RMS_NORM(0); LAUNCH_FUSED_ADD_RMS_NORM(0);
} }
} }
#define LAUNCH_FUSED_POLY_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon, \
hidden_size); \
});
void poly_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [3]
torch::Tensor& bias, // [1]
double epsilon) {
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.data_ptr() != input.data_ptr());
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
LAUNCH_FUSED_POLY_NORM(8);
} else {
LAUNCH_FUSED_POLY_NORM(0);
}
}

View File

@ -6,11 +6,10 @@
*/ */
#include "type_convert.cuh" #include "type_convert.cuh"
#include "quantization/w8a8/fp8/common.cuh" #include "quantization/fp8/common.cuh"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "cub_helpers.h" #include "cub_helpers.h"
#include "core/batch_invariant.hpp" #include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh"
#include <torch/cuda.h> #include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
@ -29,22 +28,10 @@ __global__ void rms_norm_static_fp8_quant_kernel(
__shared__ float s_variance; __shared__ float s_variance;
float variance = 0.0f; float variance = 0.0f;
const scalar_t* input_row = input + blockIdx.x * input_stride; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * input_stride + idx];
constexpr int VEC_SIZE = 8;
auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
float x = static_cast<float>(vec.val[i]);
variance += x * x;
}
};
auto scalar_op = [&variance](const scalar_t& val) {
float x = static_cast<float>(val);
variance += x * x; variance += x * x;
}; }
vllm::vectorize_read_with_alignment<VEC_SIZE>(
input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
using BlockReduce = cub::BlockReduce<float, 1024>; using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore; __shared__ typename BlockReduce::TempStorage reduceStore;
@ -229,8 +216,6 @@ void fused_add_rms_norm_static_fp8_quant(
double epsilon) { double epsilon) {
TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(residual.is_contiguous()); TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(residual.scalar_type() == input.scalar_type());
TORCH_CHECK(weight.scalar_type() == input.scalar_type());
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int input_stride = input.stride(-2); int input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
@ -256,7 +241,7 @@ void fused_add_rms_norm_static_fp8_quant(
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr()); auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned = bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
bool batch_invariant_launch = vllm::vllm_is_batch_invariant(); bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 && if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
!batch_invariant_launch) { !batch_invariant_launch) {
LAUNCH_FUSED_ADD_RMS_NORM(8); LAUNCH_FUSED_ADD_RMS_NORM(8);

View File

@ -8,77 +8,12 @@
#include "../cuda_compat.h" #include "../cuda_compat.h"
#include "../dispatch_utils.h" #include "../dispatch_utils.h"
#include "core/math.hpp"
#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) #define CEILDIV(x, y) (((x) + (y) - 1) / (y))
namespace vllm { namespace vllm {
namespace moe { namespace moe {
namespace batched_moe_align_block_size {
// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel.
static constexpr int32_t num_threads = 1024;
static constexpr int32_t num_blocks = 1;
__global__ void batched_moe_align_block_size_kernel(
int32_t const num_batches, int32_t const max_tokens_per_batch,
int32_t const block_size, int32_t const* __restrict__ batch_num_tokens,
int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids,
int32_t* __restrict__ num_tokens_post_pad) {
// TODO(varun): This is a naive implementation. Could be optimized.
size_t const batch_id = threadIdx.x;
size_t const stride = blockDim.x * gridDim.x;
int32_t const num_blocks_per_batch =
CEILDIV(max_tokens_per_batch, block_size);
int32_t const sorted_ids_size =
num_blocks_per_batch * num_batches * block_size;
int32_t const block_ids_size = sorted_ids_size / block_size;
int32_t const SENTINEL =
num_batches * max_tokens_per_batch; // To denote invalid entries.
// Intialize sorted_ids
for (size_t i = threadIdx.x; i < sorted_ids_size; i += stride) {
sorted_ids[i] = SENTINEL;
}
// Intialize expert_ids with -1
for (size_t i = threadIdx.x; i < block_ids_size; i += stride) {
block_ids[i] = -1;
}
int32_t b_num_tokens = 0;
if (batch_id < num_batches) {
b_num_tokens = batch_num_tokens[batch_id];
}
int32_t const ceil_b_num_tokens =
CEILDIV(b_num_tokens, block_size) * block_size;
// Compute prefix sum over token counts per expert
using BlockScan = cub::BlockScan<int32_t, 1024>;
__shared__ typename BlockScan::TempStorage temp_storage;
int cumsum_val;
BlockScan(temp_storage).ExclusiveSum(ceil_b_num_tokens, cumsum_val);
__syncthreads();
bool const is_last_batch = batch_id == (num_batches - 1);
if (is_last_batch) {
*num_tokens_post_pad = cumsum_val + ceil_b_num_tokens;
}
if (batch_id < num_batches) {
int32_t const batch_offset = batch_id * max_tokens_per_batch;
for (size_t i = 0; i < b_num_tokens; ++i) {
sorted_ids[cumsum_val + i] = batch_offset + i;
}
int32_t const block_start = cumsum_val / block_size;
int32_t const num_blocks = ceil_b_num_tokens / block_size;
for (size_t i = 0; i < num_blocks; ++i) {
block_ids[block_start + i] = batch_id;
}
}
}
} // namespace batched_moe_align_block_size
template <typename scalar_t> template <typename scalar_t>
__global__ void moe_align_block_size_kernel( __global__ void moe_align_block_size_kernel(
const scalar_t* __restrict__ topk_ids, const scalar_t* __restrict__ topk_ids,
@ -345,33 +280,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
}); });
} }
void batched_moe_align_block_size(int64_t max_tokens_per_batch,
int64_t block_size,
torch::Tensor const& batch_num_tokens,
torch::Tensor sorted_ids,
torch::Tensor batch_ids,
torch::Tensor num_tokens_post_pad) {
namespace batched_kernel = vllm::moe::batched_moe_align_block_size;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int32_t const B = batch_num_tokens.size(0);
int32_t const num_blocks_per_batch =
round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size;
int32_t const num_blocks = num_blocks_per_batch * B;
int64_t const sorted_ids_size = num_blocks * block_size;
TORCH_CHECK(sorted_ids.size(0) == sorted_ids_size);
TORCH_CHECK(batch_ids.size(0) == sorted_ids_size / block_size);
TORCH_CHECK(num_tokens_post_pad.size(0) == 1);
TORCH_CHECK(B <= batched_kernel::num_threads);
batched_kernel::batched_moe_align_block_size_kernel<<<
batched_kernel::num_blocks, batched_kernel::num_threads, 0, stream>>>(
B, max_tokens_per_batch, block_size, batch_num_tokens.data_ptr<int32_t>(),
sorted_ids.data_ptr<int32_t>(), batch_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>());
}
void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
torch::Tensor& output) // [num_tokens, hidden_size] torch::Tensor& output) // [num_tokens, hidden_size]
{ {

View File

@ -1,173 +0,0 @@
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"
namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
int32_t col) {
return row * total_col + col;
}
} // namespace
// TODO: Refactor common parts with moe_align_sum_kernels
template <typename scalar_t, typename token_cnts_t>
__global__ void moe_lora_align_sum_kernel(
scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping,
int64_t block_size, int num_experts, int max_loras, size_t numel,
int max_num_tokens_padded, int max_num_m_blocks,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int topk_num, int32_t* total_tokens_post_pad) {
const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
int lora_id = blockIdx.x;
extern __shared__ int32_t shared_mem[];
int32_t* cumsum = shared_mem;
token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);
// Initialize sorted_token_ids with numel
for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
sorted_token_ids[lora_id * max_num_tokens_padded + it] = numel;
}
// Initialize expert_ids with -1
for (size_t it = threadIdx.x; it < max_num_m_blocks; it += blockDim.x) {
expert_ids[lora_id * max_num_m_blocks + it] = -1;
}
// Initialize total_tokens_post_pad with 0
if (threadIdx.x == 0) {
total_tokens_post_pad[lora_id] = 0;
}
for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
}
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int mask = token_lora_mapping[i / topk_num] == lora_id;
int idx = index(num_experts, threadIdx.x + 1, topk_ids[i]);
tokens_cnts[idx] += mask;
}
__syncthreads();
// For each expert we accumulate the token counts from the different threads.
if (threadIdx.x < num_experts) {
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
}
__syncthreads();
// We accumulate the token counts of all experts in thread 0.
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i - 1] +
div_ceil(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
block_size) *
block_size;
}
total_tokens_post_pad[lora_id] = static_cast<int32_t>(cumsum[num_experts]);
}
__syncthreads();
/**
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[index(max_num_m_blocks, lora_id, i / block_size)] =
threadIdx.x;
}
}
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t rank_post_pad =
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
cumsum[expert_id];
int mask = (int)token_lora_mapping[i / topk_num] == lora_id;
atomicAdd(
&sorted_token_ids[index(max_num_tokens_padded, lora_id, rank_post_pad)],
(i - numel) * mask);
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] += mask;
}
}
void moe_lora_align_block_size(torch::Tensor topk_ids,
torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size,
int64_t max_loras,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad) {
const int topk_num = topk_ids.size(1);
int max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1);
TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
max_num_tokens_padded = round_to_next_multiple_of(
max_num_tokens_padded, static_cast<int>(block_size));
int max_num_m_blocks = div_ceil(max_num_tokens_padded, block_size);
int device_max_shared_mem;
auto dev = topk_ids.get_device();
cudaDeviceGetAttribute(&device_max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int32_t num_thread = max((int32_t)num_experts, 128); // WARP_SIZE,
TORCH_CHECK(num_thread <= 1024,
"num_thread must be less than 1024, "
"and fallback is not implemented yet.");
const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) +
(num_experts + 1) * sizeof(int32_t);
if (shared_mem > device_max_shared_mem) {
TORCH_CHECK(false,
"Shared memory usage exceeds device limit, and global memory "
"fallback is not implemented yet.");
}
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] {
dim3 blockDim(num_thread);
auto kernel = moe_lora_align_sum_kernel<scalar_t, int32_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem));
kernel<<<max_loras, blockDim, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(),
token_lora_mapping.data_ptr<int32_t>(), block_size, num_experts,
max_loras, topk_ids.numel(), max_num_tokens_padded,
max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
expert_ids.data_ptr<int32_t>(), topk_num,
num_tokens_post_pad.data_ptr<int32_t>());
});
}

View File

@ -4,7 +4,7 @@
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices, torch::Tensor& token_expert_indices,
torch::Tensor& gating_output, bool renormalize); torch::Tensor& gating_output);
void moe_sum(torch::Tensor& input, torch::Tensor& output); void moe_sum(torch::Tensor& input, torch::Tensor& output);
@ -12,21 +12,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids, int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad); torch::Tensor num_tokens_post_pad);
void batched_moe_align_block_size(int64_t max_tokens_per_batch,
int64_t block_size,
torch::Tensor const& expert_num_tokens,
torch::Tensor sorted_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad);
void moe_lora_align_block_size(torch::Tensor topk_ids,
torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size,
int64_t max_loras,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM #ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor b_qweight, torch::Tensor b_scales, torch::Tensor b_qweight, torch::Tensor b_scales,

View File

@ -16,22 +16,12 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <type_traits>
#include <torch/all.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h" #include "../cuda_compat.h"
#include "../cub_helpers.h" #include "../cub_helpers.h"
#include "../core/batch_invariant.hpp"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
typedef __hip_bfloat16 __nv_bfloat16;
typedef __hip_bfloat162 __nv_bfloat162;
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
@ -47,27 +37,16 @@ template <
/// Alignment requirement in bytes /// Alignment requirement in bytes
int Alignment = sizeof(T) * N int Alignment = sizeof(T) * N
> >
struct alignas(Alignment) AlignedArray { class alignas(Alignment) AlignedArray {
T data[N]; float data[N];
}; };
template <typename T>
__device__ __forceinline__ float toFloat(T value) {
if constexpr (std::is_same_v<T, float>) {
return value;
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
return __bfloat162float(value);
} else if constexpr (std::is_same_v<T, __half>) {
return __half2float(value);
}
}
// ====================== Softmax things =============================== // ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output // We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing. // in the softmax kernel when we extend this module to support expert-choice routing.
template <int TPB, typename InputType> template <int TPB>
__launch_bounds__(TPB) __global__ __launch_bounds__(TPB) __global__
void moeSoftmax(const InputType* input, const bool* finished, float* output, const int num_cols) void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
{ {
using BlockReduce = cub::BlockReduce<float, TPB>; using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage; __shared__ typename BlockReduce::TempStorage tmpStorage;
@ -88,8 +67,7 @@ __launch_bounds__(TPB) __global__
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{ {
const int idx = thread_row_offset + ii; const int idx = thread_row_offset + ii;
const float val = toFloat(input[idx]); threadData = max(static_cast<float>(input[idx]), threadData);
threadData = max(val, threadData);
} }
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, CubMaxOp()); const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, CubMaxOp());
@ -104,8 +82,7 @@ __launch_bounds__(TPB) __global__
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{ {
const int idx = thread_row_offset + ii; const int idx = thread_row_offset + ii;
const float val = toFloat(input[idx]); threadData += exp((static_cast<float>(input[idx]) - float_max));
threadData += expf(val - float_max);
} }
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, CubAddOp()); const auto Z = BlockReduce(tmpStorage).Reduce(threadData, CubAddOp());
@ -119,9 +96,8 @@ __launch_bounds__(TPB) __global__
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{ {
const int idx = thread_row_offset + ii; const int idx = thread_row_offset + ii;
const float val = toFloat(input[idx]); const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
const float softmax_val = expf(val - float_max) * normalizing_factor; output[idx] = val;
output[idx] = softmax_val;
} }
} }
@ -135,8 +111,7 @@ __launch_bounds__(TPB) __global__ void moeTopK(
const int num_experts, const int num_experts,
const int k, const int k,
const int start_expert, const int start_expert,
const int end_expert, const int end_expert)
const bool renormalize)
{ {
using cub_kvp = cub::KeyValuePair<int, float>; using cub_kvp = cub::KeyValuePair<int, float>;
@ -151,7 +126,6 @@ __launch_bounds__(TPB) __global__ void moeTopK(
const bool row_is_active = finished ? !finished[block_row] : true; const bool row_is_active = finished ? !finished[block_row] : true;
const int thread_read_offset = blockIdx.x * num_experts; const int thread_read_offset = blockIdx.x * num_experts;
float selected_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx) for (int k_idx = 0; k_idx < k; ++k_idx)
{ {
thread_kvp.key = 0; thread_kvp.key = 0;
@ -190,23 +164,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(
indices[idx] = should_process_row ? (expert - start_expert) : num_experts; indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
assert(indices[idx] >= 0); assert(indices[idx] >= 0);
source_rows[idx] = k_idx * num_rows + block_row; source_rows[idx] = k_idx * num_rows + block_row;
if (renormalize) {
selected_sum += result_kvp.value;
}
} }
__syncthreads(); __syncthreads();
} }
// Renormalize the k weights for this row to sum to 1, if requested.
if (renormalize) {
if (threadIdx.x == 0) {
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int idx = k * block_row + k_idx;
output[idx] = output[idx] / denom;
}
}
}
} }
// ====================== TopK softmax things =============================== // ====================== TopK softmax things ===============================
@ -225,30 +185,21 @@ __launch_bounds__(TPB) __global__ void moeTopK(
2) This implementation assumes k is small, but will work for any k. 2) This implementation assumes k is small, but will work for any k.
*/ */
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType, typename InputType = float> template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
void topkGatingSoftmax(const InputType* input, const bool* finished, float* output, const int num_rows, IndType* indices, void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert, const bool renormalize) int* source_rows, const int k, const int start_expert, const int end_expert)
{ {
static_assert(std::is_same_v<InputType, float> || std::is_same_v<InputType, __nv_bfloat16> ||
std::is_same_v<InputType, __half>,
"InputType must be float, __nv_bfloat16, or __half");
// We begin by enforcing compile time assertions and setting up compile time constants. // We begin by enforcing compile time assertions and setting up compile time constants.
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
// Number of bytes each thread pulls in per load // Number of bytes each thread pulls in per load
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType); static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static constexpr int ELTS_PER_ROW = NUM_EXPERTS; static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
if constexpr (std::is_same_v<InputType, __nv_bfloat16> || std::is_same_v<InputType, __half>) {
static_assert(ELTS_PER_LDG == 1 || ELTS_PER_LDG % 2 == 0,
"ELTS_PER_LDG must be 1 or even for 16-bit conversion");
}
// Restrictions based on previous section. // Restrictions based on previous section.
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
@ -286,71 +237,27 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
// row it will read. // row it will read.
const InputType* thread_row_ptr = input + thread_row * ELTS_PER_ROW; const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
// Now, we compute the group each thread belong to in order to determine the first column to start loads. // Now, we compute the group each thread belong to in order to determine the first column to start loads.
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
const InputType* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
// this can support all powers of 2 up to 16.
// NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
// We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
using AccessType = AlignedArray<float, ELTS_PER_LDG>;
// Finally, we pull in the data from global mem // Finally, we pull in the data from global mem
float row_chunk[VPT]; float row_chunk[VPT];
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
// NOTE(zhuhaoran): dispatch different input types loading, BF16/FP16 convert to float const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
if constexpr (std::is_same_v<InputType, float>) {
using VecType = AlignedArray<float, ELTS_PER_LDG>;
VecType* row_chunk_vec_ptr = reinterpret_cast<VecType*>(&row_chunk);
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll #pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; {
} row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
} else if constexpr (std::is_same_v<InputType, __nv_bfloat16>) {
if constexpr (ELTS_PER_LDG >= 2) {
using VecType = AlignedArray<__nv_bfloat16, ELTS_PER_LDG>;
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
#pragma unroll
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
row_chunk_f2[base_idx_f2 + jj] = __bfloat1622float2(
*reinterpret_cast<const __nv_bfloat162*>(vec.data + jj * 2)
);
}
}
} else { // ELTS_PER_LDG == 1
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
const __nv_bfloat16* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW;
row_chunk[ii] = __bfloat162float(*scalar_ptr);
}
}
} else if constexpr (std::is_same_v<InputType, __half>) {
if constexpr (ELTS_PER_LDG >= 2) {
using VecType = AlignedArray<__half, ELTS_PER_LDG>;
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
#pragma unroll
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
row_chunk_f2[base_idx_f2 + jj] = __half22float2(
*reinterpret_cast<const __half2*>(vec.data + jj * 2)
);
}
}
} else { // ELTS_PER_LDG == 1
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
const __half* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW;
row_chunk[ii] = __half2float(*scalar_ptr);
}
}
} }
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
@ -404,7 +311,6 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
int start_col = first_elt_read_by_thread; int start_col = first_elt_read_by_thread;
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
float selected_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx) for (int k_idx = 0; k_idx < k; ++k_idx)
{ {
// First, each thread does the local argmax // First, each thread does the local argmax
@ -458,9 +364,6 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
output[idx] = max_val; output[idx] = max_val;
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row; source_rows[idx] = k_idx * num_rows + thread_row;
if (renormalize) {
selected_sum += max_val;
}
} }
// Finally, we clear the value in the thread with the current max if there is another iteration to run. // Finally, we clear the value in the thread with the current max if there is another iteration to run.
@ -478,28 +381,15 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
} }
} }
} }
// Renormalize the k weights for this row to sum to 1, if requested.
if (renormalize) {
if (thread_group_idx == 0)
{
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
const int idx = k * thread_row + k_idx;
output[idx] = output[idx] / denom;
}
}
}
} }
namespace detail namespace detail
{ {
// Constructs some constants needed to partition the work across threads at compile time. // Constructs some constants needed to partition the work across threads at compile time.
template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename InputType> template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM>
struct TopkConstants struct TopkConstants
{ {
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType); static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, ""); static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, "");
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM)); static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
@ -508,21 +398,21 @@ struct TopkConstants
}; };
} // namespace detail } // namespace detail
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType, typename InputType> template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType>
void topkGatingSoftmaxLauncherHelper(const InputType* input, const bool* finished, float* output, IndType* indices, void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, const bool renormalize, int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
cudaStream_t stream)
{ {
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(InputType) * EXPERTS); static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM, InputType>; using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
static constexpr int VPT = Constants::VPT; static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; const bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
const int num_warps = batch_invariant_launch ? 32 : (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB); dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM, IndType, InputType><<<num_blocks, block_dim, 0, stream>>>( topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM><<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert, renormalize); input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
} }
#ifndef USE_ROCM #ifndef USE_ROCM
@ -530,26 +420,26 @@ void topkGatingSoftmaxLauncherHelper(const InputType* input, const bool* finishe
static_assert(WARP_SIZE == 32, \ static_assert(WARP_SIZE == 32, \
"Unsupported warp size. Only 32 is supported for CUDA"); \ "Unsupported warp size. Only 32 is supported for CUDA"); \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \ topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ gating_output, nullptr, topk_weights, topk_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream); token_expert_indices, num_tokens, topk, 0, num_experts, stream);
#else #else
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \ #define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
if (WARP_SIZE == 64) { \ if (WARP_SIZE == 64) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES>( \ topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ gating_output, nullptr, topk_weights, topk_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream); \ token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
} else if (WARP_SIZE == 32) { \ } else if (WARP_SIZE == 32) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES>( \ topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ gating_output, nullptr, topk_weights, topk_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream); \ token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
} else { \ } else { \
assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \ assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
} }
#endif #endif
template <typename IndType, typename InputType> template <typename IndType>
void topkGatingSoftmaxKernelLauncher( void topkGatingSoftmaxKernelLauncher(
const InputType* gating_output, const float* gating_output,
float* topk_weights, float* topk_weights,
IndType* topk_indices, IndType* topk_indices,
int* token_expert_indices, int* token_expert_indices,
@ -557,15 +447,11 @@ void topkGatingSoftmaxKernelLauncher(
const int num_tokens, const int num_tokens,
const int num_experts, const int num_experts,
const int topk, const int topk,
const bool renormalize,
cudaStream_t stream) { cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4; static constexpr int WARPS_PER_TB = 4;
static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16; static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
#ifndef USE_ROCM #ifndef USE_ROCM
// for bfloat16 dtype, we need 4 bytes loading to make sure num_experts static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8;
// elements can be loaded by a warp
static constexpr int BYTES_PER_LDG_MULTIPLE_64 =
(std::is_same_v<InputType, __nv_bfloat16> || std::is_same_v<InputType, __half>) ? 4 : 8;
#endif #endif
switch (num_experts) { switch (num_experts) {
case 1: case 1:
@ -622,11 +508,11 @@ void topkGatingSoftmaxKernelLauncher(
TORCH_CHECK(softmax_workspace != nullptr, TORCH_CHECK(softmax_workspace != nullptr,
"softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64."); "softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64.");
static constexpr int TPB = 256; static constexpr int TPB = 256;
moeSoftmax<TPB, InputType><<<num_tokens, TPB, 0, stream>>>( moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
gating_output, nullptr, softmax_workspace, num_experts); gating_output, nullptr, softmax_workspace, num_experts);
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>( moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices, softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices,
num_experts, topk, 0, num_experts, renormalize); num_experts, topk, 0, num_experts);
} }
} }
} }
@ -634,50 +520,11 @@ void topkGatingSoftmaxKernelLauncher(
} // namespace moe } // namespace moe
} // namespace vllm } // namespace vllm
template<typename ComputeType>
void dispatch_topk_softmax_launch(
torch::Tensor& gating_output,
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& softmax_workspace,
int num_tokens, int num_experts, int topk, bool renormalize, cudaStream_t stream)
{
if (topk_indices.scalar_type() == at::ScalarType::Int) {
vllm::moe::topkGatingSoftmaxKernelLauncher<int, ComputeType>(
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens, num_experts, topk, renormalize, stream);
} else if (topk_indices.scalar_type() == at::ScalarType::UInt32) {
vllm::moe::topkGatingSoftmaxKernelLauncher<uint32_t, ComputeType>(
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens, num_experts, topk, renormalize, stream);
} else {
TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long);
vllm::moe::topkGatingSoftmaxKernelLauncher<int64_t, ComputeType>(
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int64_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens, num_experts, topk, renormalize, stream);
}
}
void topk_softmax( void topk_softmax(
torch::Tensor& topk_weights, // [num_tokens, topk] torch::Tensor& topk_weights, // [num_tokens, topk]
torch::Tensor& topk_indices, // [num_tokens, topk] torch::Tensor& topk_indices, // [num_tokens, topk]
torch::Tensor& token_expert_indices, // [num_tokens, topk] torch::Tensor& token_expert_indices, // [num_tokens, topk]
torch::Tensor& gating_output, // [num_tokens, num_experts] torch::Tensor& gating_output) // [num_tokens, num_experts]
bool renormalize)
{ {
const int num_experts = gating_output.size(-1); const int num_experts = gating_output.size(-1);
const auto num_tokens = gating_output.numel() / num_experts; const auto num_tokens = gating_output.numel() / num_experts;
@ -689,19 +536,45 @@ void topk_softmax(
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto workspace_options = gating_output.options().dtype(at::ScalarType::Float); torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
torch::Tensor softmax_workspace = torch::empty({workspace_size}, workspace_options);
if (gating_output.scalar_type() == at::ScalarType::Float) { if(topk_indices.scalar_type() == at::ScalarType::Int)
dispatch_topk_softmax_launch<float>(gating_output, topk_weights, topk_indices, {
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); vllm::moe::topkGatingSoftmaxKernelLauncher(
} else if (gating_output.scalar_type() == at::ScalarType::Half) { gating_output.data_ptr<float>(),
dispatch_topk_softmax_launch<__half>(gating_output, topk_weights, topk_indices, topk_weights.data_ptr<float>(),
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); topk_indices.data_ptr<int>(),
} else if (gating_output.scalar_type() == at::ScalarType::BFloat16) { token_expert_indices.data_ptr<int>(),
dispatch_topk_softmax_launch<__nv_bfloat16>(gating_output, topk_weights, topk_indices, softmax_workspace.data_ptr<float>(),
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); num_tokens,
} else { num_experts,
TORCH_CHECK(false, "Unsupported gating_output data type: ", gating_output.scalar_type()); topk,
stream);
}
else if (topk_indices.scalar_type() == at::ScalarType::UInt32)
{
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}
else {
TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long);
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int64_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
} }
} }

View File

@ -5,7 +5,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs. // Apply topk softmax to the gating outputs.
m.def( m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output, bool renormalize) -> ()"); "token_expert_indices, Tensor gating_output) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax); m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
// Calculate the result of moe by summing up the partial results // Calculate the result of moe by summing up the partial results
@ -22,29 +22,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()"); " Tensor! num_tokens_post_pad) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size, but for the batched case.
m.def(
"batched_moe_align_block_size(int max_tokens_per_batch,"
" int block_size, Tensor expert_num_tokens,"
" Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()");
m.impl("batched_moe_align_block_size", torch::kCUDA,
&batched_moe_align_block_size);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
m.def(
"moe_lora_align_block_size(Tensor topk_ids,"
" Tensor token_lora_mapping,"
" int num_experts,"
" int block_size, int max_loras, "
" Tensor !sorted_token_ids,"
" Tensor !experts_ids,"
" Tensor !num_tokens_post_pad) -> () ");
m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);
#ifndef USE_ROCM #ifndef USE_ROCM
m.def( m.def(
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, " "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "

View File

@ -92,16 +92,14 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon); torch::Tensor& weight, double epsilon);
void poly_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
torch::Tensor& bias, double epsilon);
void apply_repetition_penalties_(torch::Tensor& logits, void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask, const torch::Tensor& prompt_mask,
const torch::Tensor& output_mask, const torch::Tensor& output_mask,
const torch::Tensor& repetition_penalties); const torch::Tensor& repetition_penalties);
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
const torch::Tensor& rowEnds, torch::Tensor& indices,
torch::Tensor& values, int64_t numRows, int64_t stride0,
int64_t stride1);
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale, torch::Tensor& weight, torch::Tensor& scale,
double epsilon); double epsilon);
@ -135,12 +133,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch::Tensor& input, torch::Tensor& input,
torch::Tensor& input_global_scale); torch::Tensor& input_global_scale);
#endif #endif
void persistent_masked_m_silu_mul_quant( void silu_mul_fp8_quant_deep_gemm_cuda(
const at::Tensor& input, // (E, T, 2*H) const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E) const at::Tensor& counts, // (E)
at::Tensor& y_q, // (E, T, H) [OUT] at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT] at::Tensor& y_s, // (E, T, H//group_size) [OUT]
bool use_ue8m0); int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens);
void mul_and_silu(torch::Tensor& out, torch::Tensor& input); void mul_and_silu(torch::Tensor& out, torch::Tensor& input);

View File

@ -7,7 +7,7 @@
#include "../cuda_compat.h" #include "../cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "quantization/w8a8/fp8/common.cuh" #include "quantization/fp8/common.cuh"
#include <c10/util/Float8_e4m3fn.h> #include <c10/util/Float8_e4m3fn.h>
@ -114,22 +114,13 @@ __global__ void act_and_mul_quant_kernel(
} }
__device__ __forceinline__ float silu(float x) { __device__ __forceinline__ float silu(float x) {
return __fdividef(x, (1.f + expf(-x))); return (__fdividef(x, (1.f + expf(-x))));
} }
__device__ __forceinline__ float2 silu2(float2 x) { __device__ __forceinline__ float2 silu2(float2 x) {
return make_float2(silu(x.x), silu(x.y)); return make_float2(silu(x.x), silu(x.y));
} }
__device__ __forceinline__ __nv_bfloat162 silu2_v2(float2 x) {
#ifndef USE_ROCM
return make_bfloat162(__float2bfloat16_rn(silu(x.x)),
__float2bfloat16_rn(silu(x.y)));
#else
return __float22bfloat162_rn(make_float2(silu(x.x), silu(x.y)));
#endif
}
#ifndef USE_ROCM #ifndef USE_ROCM
__device__ __forceinline__ float warp_max(float v) { __device__ __forceinline__ float warp_max(float v) {
static constexpr unsigned FULL_MASK = 0xffffffffu; static constexpr unsigned FULL_MASK = 0xffffffffu;
@ -232,308 +223,224 @@ constexpr __nv_bfloat16 get_fp8_min() {
return __nv_bfloat16(__nv_bfloat16_raw{.x = 50032}); return __nv_bfloat16(__nv_bfloat16_raw{.x = 50032});
} }
} }
#ifndef USE_ROCM
template <typename Idx_t> template <typename fp8_type, int32_t NUM_WARPS, typename Idx_t,
__device__ __forceinline__ int warp_expert_search( int NUM_PARALLEL_TOKENS, bool USE_UE8M0, int GROUP_SIZE = 128,
int idx, int n, const Idx_t* __restrict__ input, Idx_t val) {
const Idx_t* input_ptr = input + idx;
int base_offset = 0;
for (;;) {
bool move_on = (idx < n && *input_ptr <= val);
unsigned mask = __ballot_sync(0xffffffff, move_on);
if (mask != 0xffffffffu) {
int last_lane = 31 - __clz(mask);
return base_offset + last_lane;
}
input_ptr += 32;
base_offset += 32;
idx += 32;
}
}
template <int num_parallel_tokens>
__device__ __forceinline__ void token_bounds(int32_t n_tokens,
int32_t worker_id,
int32_t& n_tokens_lower,
int32_t& n_tokens_upper) {
if (n_tokens < num_parallel_tokens && worker_id < n_tokens) {
if (worker_id >= num_parallel_tokens) return;
n_tokens_lower = worker_id;
n_tokens_upper = worker_id + 1;
} else {
int32_t chunk_size = n_tokens / num_parallel_tokens;
int32_t residual = n_tokens - chunk_size * num_parallel_tokens;
auto calc_id = [&](int32_t id) {
if (id < residual)
return min(n_tokens, id * (chunk_size + 1));
else
return min(n_tokens, id * chunk_size + residual);
};
n_tokens_lower = calc_id(worker_id);
n_tokens_upper = calc_id(worker_id + 1);
}
}
template <int BLOCK_COUNT, int SMEM_SIZE_BYTES_Y, typename fp8_type,
int THREADS, typename Idx_t, bool USE_UE8M0, int GROUP_SIZE = 128,
int NUM_STAGES = 3> int NUM_STAGES = 3>
__global__ void silu_mul_fp8_quant_deep_gemm_kernel( __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q, const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
float* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert, float* __restrict__ _y_s, const int32_t* __restrict__ counts,
// sizes // sizes
Idx_t E, Idx_t T, Idx_t H, int H, int G,
// strides (in elements) // strides (in elements)
Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e, Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e,
Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t, Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t,
Idx_t stride_ys_g, Idx_t stride_counts_e) { Idx_t stride_ys_g, Idx_t stride_counts_e) {
#ifndef USE_ROCM
static constexpr int NUM_WARPS = THREADS / WARP_SIZE;
static constexpr int LOAD_STAGE_SIZE = 2 * GROUP_SIZE / 8;
static constexpr int LOAD_STAGE_MOD = NUM_STAGES * LOAD_STAGE_SIZE;
static constexpr int COMPUTE_STAGE_SIZE = 2 * GROUP_SIZE / 4;
static constexpr int COMPUTE_STAGE_MOD = COMPUTE_STAGE_SIZE * NUM_STAGES;
extern __shared__ __align__(16) __int128_t smem_128[];
int* s_expert_offsets =
reinterpret_cast<int*>(smem_128 + (SMEM_SIZE_BYTES_Y / 16));
static constexpr __nv_bfloat16 fp8_min = get_fp8_min<fp8_type>(); static constexpr __nv_bfloat16 fp8_min = get_fp8_min<fp8_type>();
static constexpr __nv_bfloat16 fp8_max = get_fp8_max<fp8_type>(); static constexpr __nv_bfloat16 fp8_max = get_fp8_max<fp8_type>();
// We assign EPS with it's 16-bit unsigned counterpart to allow constexpr. // We assign EPS with its 16-bit unsigned counterpart to allow constexpr.
static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996}); static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996});
int tid = threadIdx.x;
int warp_id = tid >> 5;
int lane_id = tid & 0x1f;
int running_sum{}; // We pack 8 16-bit bfloat16 values into a 128-bit __int128_t.
if (!warp_id) { static constexpr int32_t BFLOAT16_PER_GROUP = 8;
for (int i = 0; i < E; i += WARP_SIZE) {
bool valid = (i + threadIdx.x) < E;
int value =
(valid ? tokens_per_expert[i + threadIdx.x * stride_counts_e] : 0) +
(!lane_id ? running_sum : 0);
for (int offset = 1; offset < 32; offset *= 2) { // We split the shared memory in half, corresponding to gate and up matrices:
int n = __shfl_up_sync(0xFFFFFFFFu, value, offset); // [...gate_i, ...up_i] where 0 <= i < stages.
if (lane_id >= offset) value += n; static constexpr int32_t S_NUM_128 =
} 2u * (GROUP_SIZE / BFLOAT16_PER_GROUP) * NUM_WARPS * NUM_STAGES;
static constexpr auto THREAD_COUNT = NUM_WARPS * WARP_SIZE;
static constexpr int HALF_THREAD_COUNT = THREAD_COUNT / 2;
static constexpr int32_t S_NUM_64 = S_NUM_128 * 2;
__shared__ __int128_t __align__(16) s_buff_128[S_NUM_128];
if (valid) { const int32_t tid = threadIdx.x;
s_expert_offsets[i + threadIdx.x + 1] = value; const int32_t warp_id = tid / WARP_SIZE;
} const int32_t lane_id = tid % WARP_SIZE;
running_sum = __shfl_sync(0xFFFFFFFFu, value, WARP_SIZE - 1); auto s_buff_compute_32 = reinterpret_cast<__nv_bfloat162*>(s_buff_128);
}
if (!lane_id) { // block handles one (expert e, group g)
s_expert_offsets[0] = 0; int32_t pid = blockIdx.x;
} int32_t e = pid / G;
int32_t g = pid % G;
const int32_t n_tokens = counts[e * stride_counts_e];
if (!n_tokens) {
return; // Exit ASAP.
} }
__syncthreads(); const Idx_t stride_i_t_128 = stride_i_t / 8u;
int32_t total_tokens = s_expert_offsets[E]; int32_t n_tokens_lower, n_tokens_upper;
const int warp_position_yq = warp_id * (H / NUM_WARPS);
const int warp_position_scales = warp_id * (H / (GROUP_SIZE * NUM_WARPS));
// A single block will handle tokens_per_block tokens.
// Each block i iterates over tokens of a slice of n_tokens = // Each block i iterates over tokens of a slice of n_tokens =
// expert_counts[i], with the size of chunk being // expert_counts[i], with the size of chunk being
// (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of // (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of
// updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling. // updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling.
if (n_tokens < NUM_PARALLEL_TOKENS && blockIdx.y < n_tokens) {
// Each warp will get space to store its hidden dim for gate and up. // Specialize this, but can be likely fused.
__int128_t* s_hidden_load = smem_128 + warp_id * ((2 * 128 / 8) * NUM_STAGES); if (blockIdx.y >= NUM_PARALLEL_TOKENS) {
__int128_t* smem_load_ptr = s_hidden_load + lane_id; return;
}
const __nv_bfloat16 fp8_inv = __hdiv(__float2bfloat16(1.f), fp8_max); n_tokens_lower = blockIdx.y;
n_tokens_upper = blockIdx.y + 1;
int32_t compute_pipeline_offset_64 = 0;
int32_t load_stage_offset{};
const __nv_bfloat16 one_bf16 = __float2bfloat16_rn(1.f);
__int64_t* smem_compute_ptr = reinterpret_cast<__int64_t*>(smem_128) +
warp_id * (2 * (GROUP_SIZE / 4) * NUM_STAGES) +
lane_id;
__int64_t* s_gate64_ptr = smem_compute_ptr;
__int64_t* s_up64_ptr = smem_compute_ptr + GROUP_SIZE / 4;
int tokens_lower, tokens_upper;
token_bounds<BLOCK_COUNT>(total_tokens, blockIdx.x, tokens_lower,
tokens_upper);
Idx_t expert_id{}, expert_offset{}, next_expert_offset{};
int token_id = tokens_lower;
int32_t t_load{};
if (token_id < tokens_upper) {
expert_id = warp_expert_search<int>(lane_id, E, s_expert_offsets, token_id);
expert_offset = s_expert_offsets[expert_id];
next_expert_offset = s_expert_offsets[expert_id + 1];
} else { } else {
// This thread block has no work to do. auto chunk_size = n_tokens / NUM_PARALLEL_TOKENS;
auto residual = n_tokens - chunk_size * NUM_PARALLEL_TOKENS;
auto calc_id = [&](int32_t id) {
if (id < residual) {
return min(n_tokens, id * (chunk_size + 1));
} else {
return min(n_tokens, id * chunk_size + residual);
}
};
n_tokens_lower = calc_id(blockIdx.y);
n_tokens_upper = calc_id(blockIdx.y + 1);
}
if (n_tokens_lower >= n_tokens_upper) {
return; return;
} }
int t_load_bound = H / (GROUP_SIZE * NUM_WARPS); // We do calculations here, using constexpr wherever possible.
const Idx_t base_i = e * stride_i_e + NUM_WARPS * g * GROUP_SIZE * stride_i_h;
const Idx_t base_ys = e * stride_ys_e + NUM_WARPS * g * stride_ys_g;
const Idx_t base_yq =
e * stride_yq_e + NUM_WARPS * g * GROUP_SIZE * stride_yq_h;
Idx_t gate_off_128 = (base_i / static_cast<Idx_t>(8u));
auto input_128_ptr = reinterpret_cast<const __int128_t*>(_input);
auto gate_128_ptr = input_128_ptr + gate_off_128 + (tid % HALF_THREAD_COUNT) +
stride_i_t_128 * n_tokens_lower;
auto up_128_ptr = gate_128_ptr + (H * stride_i_h) / 8u;
auto y_s_ptr =
_y_s + base_ys + warp_id * stride_ys_g + n_tokens_lower * stride_ys_t;
auto y_q_ptr = _y_q + base_yq + warp_id * GROUP_SIZE +
stride_yq_t * n_tokens_lower + 4 * lane_id;
int32_t t_load = n_tokens_lower, load_stage_id = 0;
auto s_buff_gate_load_128 = s_buff_128 + (tid % HALF_THREAD_COUNT);
auto s_buff_up_load_128 = s_buff_gate_load_128 + S_NUM_128 / 2u;
int32_t stage_offset{};
Idx_t base_i = ((expert_id * stride_i_e) / 8) + static constexpr int32_t LOAD_STAGE_SIZE = (NUM_WARPS * WARP_SIZE / 2);
(token_id - expert_offset) * stride_i_t / 8; static constexpr int32_t LOAD_STAGE_MOD =
const Idx_t gate_warp_offset = NUM_STAGES * (NUM_WARPS * WARP_SIZE / 2);
warp_id * ((stride_i_h * H) / (8 * NUM_WARPS)) + (lane_id & 0b1111);
const __int128_t* input_128_ptr =
reinterpret_cast<const __int128_t*>(_input) + gate_warp_offset +
((lane_id < 16) ? 0 : ((H * stride_i_h) / 8));
__int128_t* load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i);
auto token_offset = token_id - expert_offset;
// Two halves of all threads in a block conduct global loads for gate and up,
// repsectively.
auto load_and_advance_y_pred = [&] { auto load_and_advance_y_pred = [&] {
if (t_load < t_load_bound) { if (t_load < n_tokens_upper) {
// Here we are simply continuing to load data auto s_gate_stage_128_staged_ptr = s_buff_gate_load_128 + stage_offset;
// from the current token. auto s_up_stage_128_staged_ptr = s_buff_up_load_128 + stage_offset;
auto smem_load_ptr_staged = smem_load_ptr + load_stage_offset;
// It is very important that LOAD_STAGE_SIZE is constexpr to avoid // It is very important that LOAD_STAGE_SIZE is constexpr to avoid
// unnecessary ALU ops. // unnecessary ALU ops.
load_stage_offset += LOAD_STAGE_SIZE; stage_offset += LOAD_STAGE_SIZE;
load_stage_offset %= LOAD_STAGE_MOD; stage_offset %= LOAD_STAGE_MOD;
cp_async4(smem_load_ptr_staged, load_ptr); if (tid < HALF_THREAD_COUNT) {
load_ptr += GROUP_SIZE / 8; cp_async4(s_gate_stage_128_staged_ptr, gate_128_ptr);
++t_load; gate_128_ptr += stride_i_t_128;
} else if (token_id + 1 < tokens_upper) {
// We loaded everything from the current token, let's move on
// to the next one, and we checked that we have more tokens to load.
++token_id;
t_load = 0;
if (token_id >= next_expert_offset) {
// We need to find the next expert.
do {
// This is a loop because it's possible
// that some experts are assigned 0 tokens.
// NOTE: We are guaranteed that there's at least
// one more token left so we don't have to check for
// expert_id bounds.
++expert_id;
// This skips 1 memory read.
expert_offset = next_expert_offset;
next_expert_offset = s_expert_offsets[expert_id + 1];
} while (next_expert_offset == expert_offset);
base_i = expert_id * (stride_i_e / 8);
token_offset = 0;
load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i);
} else { } else {
// We remain within the same expert, so just cp_async4(s_up_stage_128_staged_ptr, up_128_ptr);
// move by H/4 __int128_t (2 * H/8). up_128_ptr += stride_i_t_128;
base_i += stride_yq_t / 4;
token_offset++;
} }
load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i);
auto smem_load_ptr_staged = smem_load_ptr + load_stage_offset;
// It is very important that LOAD_STAGE_SIZE is constexpr to avoid
// unnecessary ALU ops.
load_stage_offset += LOAD_STAGE_SIZE;
load_stage_offset %= LOAD_STAGE_MOD;
cp_async4(smem_load_ptr_staged, load_ptr);
load_ptr += GROUP_SIZE / 8;
++t_load; ++t_load;
++load_stage_id;
} }
// We fence even if there is nothing to load to simplify pipelining. // We fence even if there is nothing to load to simplify pipelining.
cp_async_fence(); cp_async_fence();
}; };
// We need to warm-up the pipeline.
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_STAGES - 1; i++) { for (int i = 0; i < NUM_STAGES - 1; i++) {
load_and_advance_y_pred(); load_and_advance_y_pred();
} }
__nv_fp8x4_e4m3* y_q_base_ptr = __int64_t* s_gate_ptr = reinterpret_cast<__int64_t*>(
reinterpret_cast<__nv_fp8x4_e4m3*>(_y_q) + lane_id; s_buff_compute_32 + warp_id * (GROUP_SIZE / 2)) +
auto y_scale_base_ptr = _y_s + warp_position_scales * stride_ys_g; lane_id;
__int64_t* s_up_ptr = s_gate_ptr + S_NUM_64 / 2;
for (auto j = tokens_lower; j < tokens_upper; j++) { static constexpr int32_t STAGE_SIZE = (GROUP_SIZE * NUM_WARPS) / 4u;
const Idx_t base_ys = expert_id * stride_ys_e; static constexpr int32_t STAGE_MOD = STAGE_SIZE * NUM_STAGES;
auto y_s_ptr = y_scale_base_ptr + base_ys + token_offset * stride_ys_t;
__nv_fp8x4_e4m3* y_q_ptr =
y_q_base_ptr + (expert_id * stride_yq_e + token_offset * stride_yq_t +
warp_position_yq * stride_yq_h) /
4;
const int COMPUTE_LIMIT = H / (GROUP_SIZE * NUM_WARPS);
for (int i = 0; i < COMPUTE_LIMIT; i++) { int32_t compute_pipeline_offset_64 = 0;
cp_async_wait<NUM_STAGES - 2>();
__syncthreads();
load_and_advance_y_pred();
__int64_t* gate64_ptr = s_gate64_ptr + compute_pipeline_offset_64; for (int32_t t = n_tokens_lower; t < n_tokens_upper; ++t) {
__int64_t* up64_ptr = s_up64_ptr + compute_pipeline_offset_64; __nv_bfloat162 results_bf162[2];
// COMPUTE_STAGE_SIZE/MOD must also be constexpr! cp_async_wait<NUM_STAGES - 2>();
compute_pipeline_offset_64 += COMPUTE_STAGE_SIZE; __syncthreads();
compute_pipeline_offset_64 %= COMPUTE_STAGE_MOD;
__int64_t gate64 = *gate64_ptr; // We double-buffer pipelined loads so that the next load will
__int64_t up64 = *up64_ptr; // concurrently run with compute without overwrites.
load_and_advance_y_pred();
// Compute auto s_gate_compute_64 = s_gate_ptr + compute_pipeline_offset_64;
__nv_bfloat162 res[2]; auto s_up_compute_64 = s_up_ptr + compute_pipeline_offset_64;
__nv_bfloat162* s_up_comp = reinterpret_cast<__nv_bfloat162*>(&up64);
__nv_bfloat162* s_gate_comp = reinterpret_cast<__nv_bfloat162*>(&gate64); // STAGE_SIZE must also be constexpr!
compute_pipeline_offset_64 += STAGE_SIZE;
compute_pipeline_offset_64 %= STAGE_MOD;
// Each thread loads (gate/up) 2X 4X bfloat16 values into registers.
__int64_t gate64 = *s_gate_compute_64;
__nv_bfloat162* s_gate_compute_32 =
reinterpret_cast<__nv_bfloat162*>(&gate64);
__int64_t up64 = *s_up_compute_64;
__nv_bfloat162* s_up_compute_32 = reinterpret_cast<__nv_bfloat162*>(&up64);
#pragma unroll #pragma unroll
for (int32_t k = 0; k < 2; ++k) { for (int i = 0; i < 2; i++) {
__nv_bfloat162 gate = silu2_v2(__bfloat1622float2(s_gate_comp[k])); // For silu, we make sure that div is emitted.
res[k] = __hmul2(gate, s_up_comp[k]); float2 gate = silu2(__bfloat1622float2(s_gate_compute_32[i]));
} results_bf162[i] = __float22bfloat162_rn(gate);
}
auto _y_max2 = __hmax2(__habs2(res[0]), __habs2(res[1]));
_y_max2.x = __hmax(__hmax(_y_max2.x, _y_max2.y), EPS);
__nv_bfloat16 y_s = __hmul(warp_max(_y_max2.x), fp8_inv);
if constexpr (USE_UE8M0) {
y_s = hexp2(hceil(hlog2(y_s)));
}
__nv_bfloat16 inv_y = __hdiv(one_bf16, y_s);
auto y_s2 = make_bfloat162(inv_y, inv_y);
#pragma unroll #pragma unroll
for (int32_t k = 0; k < 2; ++k) { for (int i = 0; i < 2; i++) {
res[k] = clip(__hmul2(res[k], y_s2), __bfloat162bfloat162(fp8_min), results_bf162[i] = __hmul2(results_bf162[i], s_up_compute_32[i]);
__bfloat162bfloat162(fp8_max)); }
}
*y_q_ptr = __nv_fp8x4_e4m3(res[0], res[1]); auto _y_max2 =
y_q_ptr += WARP_SIZE * stride_yq_h; __hmax2(__habs2(results_bf162[0]), __habs2(results_bf162[1]));
if (!lane_id) { __nv_bfloat16 y_max_bf16 = __hmax(EPS, __hmax(_y_max2.x, _y_max2.y));
*y_s_ptr = y_s;
y_s_ptr += stride_ys_g; // An entire group is assigned to a single warp, so a simple warp reduce
} // is used.
__nv_bfloat16 y_s = warp_max(y_max_bf16) / fp8_max;
if constexpr (USE_UE8M0) {
y_s = hexp2(hceil(hlog2(y_s)));
}
auto inv_y = __float2bfloat16_rn(1.f) / y_s;
auto y_s2 = make_bfloat162(inv_y, inv_y);
#pragma unroll
for (int32_t i = 0; i < 2; ++i) {
results_bf162[i] =
clip(__hmul2(results_bf162[i], y_s2), __bfloat162bfloat162(fp8_min),
__bfloat162bfloat162(fp8_max));
}
auto fp8x4 = __nv_fp8x4_e4m3(results_bf162[0], results_bf162[1]);
*reinterpret_cast<__nv_fp8x4_e4m3*>(y_q_ptr) = fp8x4;
y_q_ptr += stride_yq_t;
if (lane_id == 0) {
*y_s_ptr = y_s;
y_s_ptr += stride_ys_t;
} }
} }
#endif
} }
#endif
} // namespace vllm } // namespace vllm
@ -568,14 +475,14 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d]
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
} }
void persistent_masked_m_silu_mul_quant( void silu_mul_fp8_quant_deep_gemm_cuda(
const at::Tensor& input, // (E, T, 2*H) const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& tokens_per_expert, // (E) const at::Tensor& counts, // (E)
at::Tensor& y_q, // (E, T, H) [OUT] at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT] at::Tensor& y_s, // (E, T, H//group_size) [OUT]
bool use_ue8m0) { int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens) {
#ifndef USE_ROCM #ifndef USE_ROCM
// This kernel relies heavily on cp.async and fp8 support.
// This kernel currently only supports H % 128 == 0 and assumes a // This kernel currently only supports H % 128 == 0 and assumes a
// fixed GROUP_SIZE of 128. // fixed GROUP_SIZE of 128.
TORCH_CHECK(input.dtype() == torch::kBFloat16); TORCH_CHECK(input.dtype() == torch::kBFloat16);
@ -584,6 +491,10 @@ void persistent_masked_m_silu_mul_quant(
TORCH_CHECK(y_s.dtype() == torch::kFloat32); TORCH_CHECK(y_s.dtype() == torch::kFloat32);
TORCH_CHECK(input.size(-1) % 256 == 0); TORCH_CHECK(input.size(-1) % 256 == 0);
// Check that num_parallel_tokens is of power of 2 and between 1 and 64.
TORCH_CHECK(1 <= num_parallel_tokens && num_parallel_tokens <= 64);
TORCH_CHECK(!(num_parallel_tokens & (num_parallel_tokens - 1)));
using Idx_t = int64_t; using Idx_t = int64_t;
Idx_t E = input.size(0); Idx_t E = input.size(0);
@ -599,54 +510,81 @@ void persistent_masked_m_silu_mul_quant(
Idx_t stride_ys_t = y_s.stride(1); Idx_t stride_ys_t = y_s.stride(1);
Idx_t stride_ys_g = y_s.stride(2); Idx_t stride_ys_g = y_s.stride(2);
Idx_t stride_counts_e = tokens_per_expert.stride(0); Idx_t stride_counts_e = counts.stride(0);
static constexpr int GROUP_SIZE = 128; static constexpr int GROUP_SIZE = 128;
#define KERNEL_FN \
if (use_ue8m0) { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
NUM_PARALLEL_TOKENS, true> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
stride_counts_e); \
} else { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
NUM_PARALLEL_TOKENS, false> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
stride_counts_e); \
}
#define KERNEL_CALL_H \
if (H % (4 * GROUP_SIZE) == 0) { \
static constexpr int NUM_WARPS = 4; \
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
KERNEL_FN \
} else { \
static constexpr int NUM_WARPS = 1; \
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
KERNEL_FN \
}
#define KERNEL_CALL_TOP_LEVEL \
if (num_parallel_tokens == 1) { \
static constexpr int NUM_PARALLEL_TOKENS = 1; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 2) { \
static constexpr int NUM_PARALLEL_TOKENS = 2; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 4) { \
static constexpr int NUM_PARALLEL_TOKENS = 4; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 8) { \
static constexpr int NUM_PARALLEL_TOKENS = 8; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 16) { \
static constexpr int NUM_PARALLEL_TOKENS = 16; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 32) { \
static constexpr int NUM_PARALLEL_TOKENS = 32; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 64) { \
static constexpr int NUM_PARALLEL_TOKENS = 64; \
KERNEL_CALL_H \
}
Idx_t G;
dim3 block, grid;
auto populate_launch_params = [&](int num_warps, int _num_parallel_tokens) {
G = H / Idx_t(group_size * num_warps);
grid = dim3(E * G, _num_parallel_tokens);
block = dim3(num_warps * WARP_SIZE);
};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
#define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \ VLLM_DISPATCH_FP8_TYPES(y_q.scalar_type(),
static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \ "silu_mul_fp8_quant_deep_gemm_kernel",
int sms = SILU_V2_BLOCK_COUNT; \ [&] { KERNEL_CALL_TOP_LEVEL });
static constexpr int max_shared_mem_bytes = \
GROUP_SIZE * 2 * STAGES * NUM_WARPS * 2; \
dim3 grid(sms), block(THREAD_COUNT); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
VLLM_DISPATCH_FP8_TYPES( \
y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel< \
BLOCK_COUNT, max_shared_mem_bytes, fp8_t, THREAD_COUNT, Idx_t, \
USE_UE8M0, GROUP_SIZE, STAGES> \
<<<grid, block, max_shared_mem_bytes + (E + 1) * 16, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E, \
T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \
stride_yq_t, stride_yq_h, stride_ys_e, stride_ys_t, \
stride_ys_g, stride_counts_e); \
});
static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
if (!use_ue8m0) {
if (H >= 4096) {
static constexpr int NUM_STAGES = 4;
static constexpr int THREAD_COUNT = 256;
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES);
} else {
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
}
} else {
if (H >= 4096) {
static constexpr int NUM_STAGES = 4;
static constexpr int THREAD_COUNT = 256;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
} else {
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
}
}
#endif #endif
} }

View File

@ -1,11 +1,15 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/all.h> #include <torch/all.h>
#ifndef USE_ROCM
#include "../per_token_group_quant_8bit.h"
#endif
#include <cmath> #include <cmath>
#include "dispatch_utils.h" #include "../../cub_helpers.h"
#include "quantization/vectorization_utils.cuh" #include "../../dispatch_utils.h"
#include "cub_helpers.h" #include "../vectorization_utils.cuh"
static inline __device__ int8_t float_to_int8_rn(float x) { static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM #ifdef USE_ROCM
@ -21,6 +25,7 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
float dst = std::nearbyint(x); float dst = std::nearbyint(x);
// saturate // saturate
// See https://github.com/pytorch/pytorch/issues/127666 // See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183 // See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on // hip-clang std::clamp __glibcxx_assert_fail host function when building on
@ -79,6 +84,7 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
static_cast<int32_t>(std::numeric_limits<int8_t>::max()); static_cast<int32_t>(std::numeric_limits<int8_t>::max());
// saturate // saturate
// See https://github.com/pytorch/pytorch/issues/127666 // See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183 // See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on // hip-clang std::clamp __glibcxx_assert_fail host function when building on
@ -170,6 +176,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax; float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;
// 2. quantize
vectorize_with_alignment<16>( vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride, row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) { [=] __device__(int8_t& dst, const scalar_t& src) {
@ -187,6 +194,7 @@ struct MinMax {
__host__ __device__ explicit MinMax(float v) : min(v), max(v) {} __host__ __device__ explicit MinMax(float v) : min(v), max(v) {}
// add a value to the MinMax
__host__ __device__ MinMax& operator+=(float v) { __host__ __device__ MinMax& operator+=(float v) {
min = fminf(min, v); min = fminf(min, v);
max = fmaxf(max, v); max = fmaxf(max, v);
@ -220,6 +228,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
const scalar_t* row_in = input + token_idx * hidden_size; const scalar_t* row_in = input + token_idx * hidden_size;
int8_t* row_out = output + token_idx * hidden_size; int8_t* row_out = output + token_idx * hidden_size;
// 1. calculate min & max
MinMax thread_mm; MinMax thread_mm;
vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride, vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
[&] __device__(const scalar_t& src) { [&] __device__(const scalar_t& src) {
@ -252,6 +261,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
const float inv_s = 1.f / scale_sh; const float inv_s = 1.f / scale_sh;
const azp_t azp = azp_sh; const azp_t azp = azp_sh;
// 2. quantize
vectorize_with_alignment<16>( vectorize_with_alignment<16>(
row_in, row_out, hidden_size, tid, stride, row_in, row_out, hidden_size, tid, stride,
[=] __device__(int8_t& dst, const scalar_t& src) { [=] __device__(int8_t& dst, const scalar_t& src) {
@ -322,4 +332,14 @@ void dynamic_scaled_int8_quant(
hidden_size); hidden_size);
} }
}); });
} }
#ifndef USE_ROCM
void per_token_group_quant_int8(const torch::Tensor& input,
torch::Tensor& output_q,
torch::Tensor& output_s, int64_t group_size,
double eps, double int8_min, double int8_max) {
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
int8_min, int8_max);
}
#endif

Some files were not shown because too many files have changed in this diff Show More