Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits: v0.5.5 ... v0.6.1.pos (204 commits)
Commits (SHA1):

acda0b35d0
ba77527955
6821020109
8427550488
3f79bc3d1a
40c396533d
5ec9c0fb3c
8f44a92d85
360ddbd37e
a480939e8e
d31174a4e1
b61bd98f90
c16369455f
019877253b
551ce01078
a6c0f3658d
f2e263b801
1f0c75afa9
8a23e93302
c6202daeed
e56bf27741
520ca380ae
7de49aa86c
42ffba11ad
295c4730a8
1bf2dd9df0
5a60699c45
b6c75e1cf2
b71c956deb
f842a7aff1
a65cb16067
3fd2b0d21c
d394787e52
775f00f81e
8baa454937
73202dbe77
7015417fd4
aea02f30de
0b952af458
3b7fea770f
cea95dfb94
6a512a00df
efcf946a15
1230263e16
e497b8aeff
94144e726c
1d5e397aa4
22f3a4bc6c
b1f3e18958
04e7c4e771
5faedf1b62
02751a7a42
f421f3cefb
8c054b7a62
6234385f4a
da1a844e61
a1d874224d
6cd5e5b07e
c7cb5c3335
f9b4a2d415
58fcc8545a
08287ef675
4ef41b8476
cfe712bf1a
b962ee1470
36bf8150cc
e807125936
9f68e00d27
ce2702a923
795b662cff
2f707fcb35
41e95c5247
12dd715807
29f49cd6e3
23f322297f
9db52eab3d
1447c97e75
de80783b69
e5cab71531
baa5467547
db3bf7c991
2febcf2777
2ee45281a5
9da25a88aa
8685ba1a1e
288a938872
e39ebf5cf5
ba262c4e5a
4624d98dbd
1afc931987
e01c2beb7d
32e7db2536
008cf886c9
77d9e514a2
e02ce498be
561d6f8077
d1dec64243
2ad2e5608e
d3311562fb
ccd7207191
855c262a6b
2be8ec6e71
e16fa99a6a
61f4a93d14
d4db9f53c8
2188a60c7e
dc0b6066ab
0af3abe3d3
f1575dc99f
c02638efb3
652c83b697
6d646d08a2
95a178f861
bd852f2a8b
ec266536b7
0fbc6696c2
6e36f4fa6c
dd2a6a82e3
4ca65a9763
e2b2aa5a0f
e6a26ed037
f8d60145b4
5b86b19954
5231f0898e
8423aef4c8
4f5d8446ed
d05f0a9db2
622f8abff8
1248e8506a
2684efc467
058344f89a
98cef6a227
f97be32d1d
afd39a4511
2148441fd3
dc13e99348
34a0e96d46
80c7b089b1
428dd1445e
4abed65c58
0c785d344d
4664ceaad6
257afc37c5
86a677de42
d78789ac16
c334b1898b
6b3421567d
3f60f2244e
f205c09854
ef99a78760
74d5543ec5
a7f65c2be9
4289cad37f
af59df0a10
ce6bf3a2cf
3cdfe1f38b
fdd9daafa3
8c56e57def
eeffde1ac0
e5697d161c
b98cc28f91
ef9baee3c5
98c12cffe5
f52a43a8b9
e3580537a4
f508e03e7f
51f86bf487
c166e7e43e
bc6e42a9b1
fab5f53e2d
9c71c97ae2
5340a2dccf
345be0e244
fc911880cc
ed6f002d33
b09c755be8
42e932c7d4
076169f603
9db642138b
6fc4e6e07a
9606c7197d
64cc644425
39178c7fbc
2eedede875
015e6cc252
760e9f71a8
05826c887b
dd9857f5fa
665304092d
2deb029d11
029c71de11
0b769992ec
1856aff4d6
70c094ade6
2059b8d9ca
8aaf3d5347
80162c44b1
aab0fcdb63
ea9fa160e3
7d9ffa2ae1
d81abefd2e
8da48e4d95
6885fde317
9db93de20c
.buildkite/check-wheel-size.py

@@ -1,36 +1,43 @@
import os
import sys
import zipfile

MAX_SIZE_MB = 250
# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 250 MB
VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 250))


def print_top_10_largest_files(zip_file):
    """Print the top 10 largest files in the given zip file."""
    with zipfile.ZipFile(zip_file, 'r') as z:
        file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()]
        file_sizes.sort(key=lambda x: x[1], reverse=True)
        for f, size in file_sizes[:10]:
            print(f"{f}: {size/(1024*1024)} MBs uncompressed.")
            print(f"{f}: {size / (1024 * 1024):.2f} MBs uncompressed.")


def check_wheel_size(directory):
    """Check the size of .whl files in the given directory."""
    for root, _, files in os.walk(directory):
        for f in files:
            if f.endswith(".whl"):
                wheel_path = os.path.join(root, f)
                wheel_size = os.path.getsize(wheel_path)
                wheel_size_mb = wheel_size / (1024 * 1024)
                if wheel_size_mb > MAX_SIZE_MB:
                    print(
                        f"Wheel {wheel_path} is too large ({wheel_size_mb} MB) "
                        f"compare to the allowed size ({MAX_SIZE_MB} MB).")
        for file_name in files:
            if file_name.endswith(".whl"):
                wheel_path = os.path.join(root, file_name)
                wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024)
                if wheel_size_mb > VLLM_MAX_SIZE_MB:
                    print(f"Not allowed: Wheel {wheel_path} is larger "
                          f"({wheel_size_mb:.2f} MB) than the limit "
                          f"({VLLM_MAX_SIZE_MB} MB).")
                    print_top_10_largest_files(wheel_path)
                    return 1
                else:
                    print(f"Wheel {wheel_path} is within the allowed size "
                          f"({wheel_size_mb} MB).")
                          f"({wheel_size_mb:.2f} MB).")
    return 0


if __name__ == "__main__":
    import sys
    sys.exit(check_wheel_size(sys.argv[1]))
    if len(sys.argv) < 2:
        print("Usage: python check-wheel-size.py <directory>")
        sys.exit(1)

    directory = sys.argv[1]
    sys.exit(check_wheel_size(directory))
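With the new VLLM_MAX_SIZE_MB knob, the limit no longer has to be edited inside the script. As a quick illustration (not part of the CI itself), the sketch below runs the checker against a local wheel directory with a temporarily raised limit; the 300 MB value and the `dist` path are arbitrary examples.

```python
import os
import subprocess

# Hypothetical local usage: raise the wheel-size limit to 300 MB for one run.
# The "dist" directory and the 300 MB figure are illustrative, not CI defaults.
env = dict(os.environ, VLLM_MAX_SIZE_MB="300")
result = subprocess.run(
    ["python3", ".buildkite/check-wheel-size.py", "dist"],
    env=env,
    check=False,  # an oversized wheel is signalled via a non-zero exit code
)
print(f"wheel size check exit code: {result.returncode}")
```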
@@ -1,5 +1,4 @@
Meta-Llama-3-8B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
.buildkite/run-amd-test.sh (65 changed lines; Normal file → Executable file)

@@ -1,5 +1,5 @@
# This script runs test inside the corresponding ROCm docker container.
set -ex
set -o pipefail

# Print ROCm version
echo "--- Confirming Clean Initial State"

@@ -70,15 +70,74 @@ HF_CACHE="$(realpath ~)/huggingface"
mkdir -p ${HF_CACHE}
HF_MOUNT="/root/.cache/huggingface"

commands=$@
echo "Commands:$commands"
#ignore certain kernels tests
if [[ $commands == *" kernels "* ]]; then
  commands="${commands} \
  --ignore=kernels/test_attention.py \
  --ignore=kernels/test_attention_selector.py \
  --ignore=kernels/test_blocksparse_attention.py \
  --ignore=kernels/test_causal_conv1d.py \
  --ignore=kernels/test_cutlass.py \
  --ignore=kernels/test_encoder_decoder_attn.py \
  --ignore=kernels/test_flash_attn.py \
  --ignore=kernels/test_flashinfer.py \
  --ignore=kernels/test_int8_quant.py \
  --ignore=kernels/test_machete_gemm.py \
  --ignore=kernels/test_mamba_ssm.py \
  --ignore=kernels/test_marlin_gemm.py \
  --ignore=kernels/test_moe.py \
  --ignore=kernels/test_prefix_prefill.py \
  --ignore=kernels/test_rand.py \
  --ignore=kernels/test_sampler.py"
fi

PARALLEL_JOB_COUNT=8
# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs.
if [[ $commands == *"--shard-id="* ]]; then
  for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do
    #replace shard arguments
    commands=${commands//"--shard-id= "/"--shard-id=${GPU} "}
    commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "}
    echo "Shard ${GPU} commands:$commands"
    docker run \
        --device /dev/kfd --device /dev/dri \
        --network host \
        --shm-size=16gb \
        --rm \
        -e HIP_VISIBLE_DEVICES=${GPU} \
        -e HF_TOKEN \
        -v ${HF_CACHE}:${HF_MOUNT} \
        -e HF_HOME=${HF_MOUNT} \
        --name ${container_name}_${GPU} \
        ${image_name} \
        /bin/bash -c "${commands}" \
        |& while read -r line; do echo ">>Shard $GPU: $line"; done &
    PIDS+=($!)
  done
  #wait for all processes to finish and collect exit codes
  for pid in ${PIDS[@]}; do
    wait ${pid}
    STATUS+=($?)
  done
  for st in ${STATUS[@]}; do
    if [[ ${st} -ne 0 ]]; then
      echo "One of the processes failed with $st"
      exit ${st}
    fi
  done
else
  docker run \
      --device /dev/kfd --device /dev/dri \
      --network host \
      --shm-size=16gb \
      --rm \
      -e HIP_VISIBLE_DEVICES=0 \
      -e HF_TOKEN \
      -v ${HF_CACHE}:${HF_MOUNT} \
      -e HF_HOME=${HF_MOUNT} \
      --name ${container_name} \
      ${image_name} \
      /bin/bash -c "${@}"
      /bin/bash -c "${commands}"
fi
.buildkite/run-cpu-test-ppc64le.sh (33 lines; new executable file)

@@ -0,0 +1,33 @@
# This script build the CPU docker image and run the offline inference inside the container.
# It serves a sanity check for compilation and basic model usage.
set -ex

# Try building the docker image
docker build -t cpu-test -f Dockerfile.ppc64le .

# Setup cleanup
remove_docker_container() { docker rm -f cpu-test || true; }
trap remove_docker_container EXIT
remove_docker_container

# Run the image, setting --shm-size=4g for tensor parallel.
source /etc/environment
#docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN=$HF_TOKEN --name cpu-test cpu-test

# Run basic model test
docker exec cpu-test bash -c "
  pip install pytest matplotlib einops transformers_stream_generator
  pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported

# online inference
docker exec cpu-test bash -c "
  python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
  timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
  python3 benchmarks/benchmark_serving.py \
    --backend vllm \
    --dataset-name random \
    --model facebook/opt-125m \
    --num-prompts 20 \
    --endpoint /v1/completions \
    --tokenizer facebook/opt-125m"
.buildkite/run-cpu-test.sh

@@ -23,7 +23,18 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
# Run basic model test
docker exec cpu-test bash -c "
  pip install pytest matplotlib einops transformers_stream_generator
  pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
  pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py \
    --ignore=tests/models/test_oot_registration.py \
    --ignore=tests/models/test_registry.py \
    --ignore=tests/models/test_fp8.py \
    --ignore=tests/models/test_jamba.py \
    --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported

# Run compressed-tensor test
docker exec cpu-test bash -c "
  pytest -s -v \
  tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
  tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynanmic_per_token"

# online inference
docker exec cpu-test bash -c "
.buildkite/run-tpu-test.sh

@@ -12,5 +12,4 @@ remove_docker_container
# For HF_TOKEN.
source /etc/environment
# Run a simple end-to-end example.
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu \
    python3 /workspace/vllm/examples/offline_inference_tpu.py
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
.buildkite/test-pipeline.yaml

@@ -50,6 +50,7 @@ steps:
  - tests/worker
  commands:
  - pytest -v -s async_engine # Async Engine
  - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
  - pytest -v -s test_inputs.py
  - pytest -v -s multimodal
  - pytest -v -s test_utils.py # Utils

@@ -87,8 +88,12 @@ steps:
  commands:
  - pip install -e ./plugins/vllm_add_dummy_model
  - pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api]
  - pytest -v -s entrypoints/llm
  - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py
  - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
  - pytest -v -s entrypoints/openai
  - pytest -v -s entrypoints/test_chat_utils.py
  - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

- label: Distributed Tests (4 GPUs) # 10min
  working_dir: "/vllm-workspace/tests"

@@ -155,6 +160,7 @@ steps:
  - python3 offline_inference_with_prefix.py
  - python3 llm_engine_example.py
  - python3 offline_inference_vision_language.py
  - python3 offline_inference_vision_language_multi_image.py
  - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
  - python3 offline_inference_encoder_decoder.py

@@ -172,6 +178,7 @@ steps:
  - vllm/
  commands:
  - pytest -v -s ./compile/test_full_graph.py
  - pytest -v -s ./compile/test_wrapper.py

- label: Vision Language Models Test # 42min

@@ -212,17 +219,19 @@ steps:
  commands:
  # See https://github.com/vllm-project/vllm/issues/5152
  - export VLLM_ATTENTION_BACKEND=XFORMERS
  - pytest -v -s spec_decode
  - pytest -v -s spec_decode/e2e/test_multistep_correctness.py
  - pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py

- label: LoRA Test %N # 30min each
  mirror_hardwares: [amd]
  source_file_dependencies:
  - vllm/lora
  - csrc/punica
  - tests/lora
  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
  parallelism: 4

- label: Kernels Test %N # 30min each
  mirror_hardwares: [amd]
  source_file_dependencies:
  - csrc/
  - vllm/attention

@@ -232,12 +241,13 @@ steps:
  parallelism: 4

- label: Tensorizer Test # 11min
  mirror_hardwares: [amd]
  soft_fail: true
  source_file_dependencies:
  - vllm/model_executor/model_loader
  - tests/tensorizer_loader
  commands:
  - apt-get install -y curl libsodium23
  - apt-get update && apt-get install -y curl libsodium23
  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
  - pytest -v -s tensorizer_loader

@@ -267,6 +277,15 @@ steps:
  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
  - bash ./run-tests.sh -c configs/models-small.txt -t 1

- label: OpenAI-Compatible Tool Use # 20 min
  fast_check: false
  mirror_hardwares: [ amd ]
  source_file_dependencies:
  - vllm/
  - tests/tool_use
  commands:
  - pytest -v -s tool_use

##### 1 GPU test #####
##### multi gpus test #####

@@ -335,7 +354,8 @@ steps:
  - vllm/engine
  - tests/multi_step
  commands:
  - pytest -v -s multi_step/test_correctness.py
  - pytest -v -s multi_step/test_correctness_async_llm.py
  - pytest -v -s multi_step/test_correctness_llm.py

- label: Pipeline Parallelism Test # 23min
  working_dir: "/vllm-workspace/tests"

@@ -353,9 +373,9 @@ steps:
- label: LoRA Long Context (Distributed) # 11min
  # This test runs llama 13B, so it is required to run on 4 GPUs.
  num_gpus: 4
  soft_fail: true
  source_file_dependencies:
  - vllm/lora
  - csrc/punica
  - tests/lora/test_long_context
  commands:
  # FIXIT: find out which code initialize cuda before running the test

@@ -370,7 +390,18 @@ steps:
  - vllm/
  - tests/weight_loading
  commands:
  - bash weight_loading/run_model_weight_loading_test.sh
  - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt

- label: Weight Loading Multiple GPU Test - Large Models # optional
  working_dir: "/vllm-workspace/tests"
  num_gpus: 2
  gpu: a100
  optional: true
  source_file_dependencies:
  - vllm/
  - tests/weight_loading
  commands:
  - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt


##### multi gpus test #####
.github/ISSUE_TEMPLATE/400-bug report.yml (9 changed lines; vendored)

@@ -30,6 +30,15 @@ body:
      </details>
    validations:
      required: true
  - type: textarea
    attributes:
      label: Model Input Dumps
      description: |
        If you are facing crashing due to illegal memory access or other issues with model execution, vLLM may dump the problematic input of the model. In this case, you will see the message `Error in model execution (input dumped to /tmp/err_xxx.pkl)`. If you see this message, please zip the file (because GitHub doesn't support .pkl file format) and upload it here. This will help us to reproduce the issue and facilitate the debugging process.
      placeholder: |
        Upload the dumped input file.
    validations:
      required: false
  - type: textarea
    attributes:
      label: 🐛 Describe the bug
.github/PULL_REQUEST_TEMPLATE.md (10 changed lines; vendored)

@@ -39,6 +39,16 @@ FIX #xxxx (*link existing issues this PR will resolve*)
<li>Please add documentation to <code>docs/source/</code> if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.</li>
</ul>

<h3>Adding or changing kernels</h3>
<p>Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.</p>
<ul>
<li>Make sure custom ops are registered following PyTorch guidelines: <a href="https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#cpp-custom-ops-tutorial">Custom C++ and CUDA Operators</a> and <a href="https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU">The Custom Operators Manual</a></li>
<li>Custom operations that return <code>Tensors</code> require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.</li>
<li>Use <a href="https://pytorch.org/docs/stable/library.html#torch.library.opcheck"><code>torch.libary.opcheck()</code></a> to test the function registration and meta-function for any registered ops. See <code>tests/kernels</code> for examples.</li>
<li>When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.</li>
<li>If a new custom type is needed, see the following document: <a href="https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA">Custom Class Support in PT2</a>.
</ul>

<h3>Notes for Large Changes</h3>
<p>Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with <code>rfc-required</code> and might not go through the PR.</p>
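The checklist above describes the expected flow for new kernels: register the op with a schema, give it a fake/meta implementation in Python so dynamic shapes can be traced, then validate the registration with opcheck. The following minimal sketch illustrates that flow outside of vLLM; the op name `myops::scaled_add`, its body, and the test shapes are invented for illustration, and it assumes PyTorch 2.4+ (where `torch.library.custom_op`, `CustomOpDef.register_fake`, and `torch.library.opcheck` are available).

```python
import torch
from torch.library import custom_op, opcheck


# Register a custom op; the schema is derived from the type annotations.
# "myops::scaled_add" is a hypothetical name, not a vLLM op.
@custom_op("myops::scaled_add", mutates_args=())
def scaled_add(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # Eager reference implementation; a real kernel would dispatch to C++/CUDA.
    return x + alpha * y


# Meta ("fake") function: describes the output shape/dtype without computing
# values, so torch.compile can handle dynamic dimensions automatically.
@scaled_add.register_fake
def _(x, y, alpha):
    return torch.empty_like(x)


if __name__ == "__main__":
    sample_args = (torch.randn(4, 8), torch.randn(4, 8), 0.5)
    # opcheck exercises the registration, schema, and fake implementation.
    opcheck(scaled_add, sample_args)
```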
.github/workflows/add_label_ready_comment.yml (23 lines removed; vendored)

@@ -1,23 +0,0 @@
name: Add Ready Label on Ready Comment

on:
  issue_comment:
    types: [created]

jobs:
  add-ready-label:
    runs-on: ubuntu-latest
    if: github.event.issue.pull_request && contains(github.event.comment.body, '/ready')
    steps:
    - name: Add label
      uses: actions/github-script@v5
      with:
        script: |
          github.rest.issues.addLabels({
            owner: context.repo.owner,
            repo: context.repo.repo,
            issue_number: context.issue.number,
            labels: ['ready']
          })
      env:
        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
.github/workflows/mypy.yaml (1 changed line; vendored)

@@ -35,7 +35,6 @@ jobs:
          mypy
          mypy tests --follow-imports skip
          mypy vllm/attention --follow-imports skip
          mypy vllm/core --follow-imports skip
          mypy vllm/distributed --follow-imports skip
          mypy vllm/engine --follow-imports skip
          mypy vllm/executor --follow-imports skip
.github/workflows/reminder_comment.yml (2 changed lines; vendored)

@@ -15,7 +15,7 @@ jobs:
            owner: context.repo.owner,
            repo: context.repo.repo,
            issue_number: context.issue.number,
            body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your `fast-check` build on Buildkite UI. \n\nOnce the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).\n\n To run full CI, you can do one of these:\n- Comment `/ready` on the PR\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
            body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org. \n\nOnce the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n To run CI, PR reviewers can do one of these:\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
          })
      env:
        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@@ -1,23 +0,0 @@
name: Remove ready Label on notready Comment

on:
  issue_comment:
    types: [created]

jobs:
  add-ready-label:
    runs-on: ubuntu-latest
    if: github.event.issue.pull_request && contains(github.event.comment.body, '/notready')
    steps:
    - name: Remove ready label
      uses: actions/github-script@v5
      with:
        script: |
          github.rest.issues.removeLabel({
            owner: context.repo.owner,
            repo: context.repo.repo,
            issue_number: context.issue.number,
            name: 'ready'
          })
      env:
        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
CMakeLists.txt

@@ -181,7 +181,6 @@ set(VLLM_EXT_SRC
  "csrc/pos_encoding_kernels.cu"
  "csrc/activation_kernels.cu"
  "csrc/layernorm_kernels.cu"
  "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
  "csrc/quantization/gptq/q_gemm.cu"
  "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
  "csrc/quantization/fp8/common.cu"

@@ -196,13 +195,19 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  FetchContent_Declare(
        cutlass
        GIT_REPOSITORY https://github.com/nvidia/cutlass.git
        # CUTLASS 3.5.1
        GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9
        GIT_TAG v3.5.1
        GIT_PROGRESS TRUE

        # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
        # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
        # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
        GIT_SHALLOW TRUE
  )
  FetchContent_MakeAvailable(cutlass)

  list(APPEND VLLM_EXT_SRC
    "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
    "csrc/mamba/causal_conv1d/causal_conv1d.cu"
    "csrc/quantization/aqlm/gemm_kernels.cu"
    "csrc/quantization/awq/gemm_kernels.cu"
    "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"

@@ -230,6 +235,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
      "-gencode arch=compute_90a,code=sm_90a")
  endif()


  #
  # Machete kernels

@@ -288,6 +294,12 @@ define_gpu_extension_target(
  USE_SABI 3
  WITH_SOABI)

# If CUTLASS is compiled on NVCC >= 12.5, it by default uses
# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the
# driver API. This causes problems when linking with earlier versions of CUDA.
# Setting this variable sidesteps the issue by calling the driver directly.
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)

#
# _moe_C extension
#

@@ -296,6 +308,11 @@ set(VLLM_MOE_EXT_SRC
  "csrc/moe/torch_bindings.cpp"
  "csrc/moe/topk_softmax_kernels.cu")

if(VLLM_GPU_LANG STREQUAL "CUDA")
  list(APPEND VLLM_MOE_EXT_SRC
    "csrc/moe/marlin_moe_ops.cu")
endif()

define_gpu_extension_target(
  _moe_C
  DESTINATION vllm
CODE_OF_CONDUCT.md (128 lines; new file)

@@ -0,0 +1,128 @@

# vLLM Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socioeconomic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
  community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or advances of
  any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
  without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official email address,
posting via an official social media account, or acting as an appointed
representative at an online or offline/IRL event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement in the #code-of-conduct
channel in the [vLLM Discord](https://discord.com/invite/jz7wjKhh6g).
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series of
actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within the
community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/),
version 2.1, available at
[v2.1](https://www.contributor-covenant.org/version/2/1/code_of_conduct.html).

Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/inclusion).

For answers to common questions about this code of conduct, see the
[Contributor Covenant FAQ](https://www.contributor-covenant.org/faq). Translations are available at
[Contributor Covenant translations](https://www.contributor-covenant.org/translations).
Dockerfile (49 changed lines)

@@ -10,7 +10,7 @@ ARG CUDA_VERSION=12.4.1
# prepare basic build environment
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
ARG CUDA_VERSION=12.4.1
ARG PYTHON_VERSION=3.10
ARG PYTHON_VERSION=3.12
ENV DEBIAN_FRONTEND=noninteractive

# Install Python and other dependencies

@@ -37,14 +37,10 @@ WORKDIR /workspace
# install build and runtime dependencies
COPY requirements-common.txt requirements-common.txt
COPY requirements-adag.txt requirements-adag.txt
COPY requirements-cuda.txt requirements-cuda.txt
RUN --mount=type=cache,target=/root/.cache/pip \
    python3 -m pip install -r requirements-cuda.txt

COPY requirements-mamba.txt requirements-mamba.txt
RUN python3 -m pip install packaging
RUN python3 -m pip install -r requirements-mamba.txt

# cuda arch list used by torch
# can be useful for both `dev` and `test`

@@ -69,7 +65,6 @@ COPY setup.py setup.py
COPY cmake cmake
COPY CMakeLists.txt CMakeLists.txt
COPY requirements-common.txt requirements-common.txt
COPY requirements-adag.txt requirements-adag.txt
COPY requirements-cuda.txt requirements-cuda.txt
COPY pyproject.toml pyproject.toml
COPY vllm vllm

@@ -111,10 +106,17 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
    python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
    fi

# check the size of the wheel, we cannot upload wheels larger than 100MB
# Check the size of the wheel if RUN_WHEEL_CHECK is true
COPY .buildkite/check-wheel-size.py check-wheel-size.py
RUN python3 check-wheel-size.py dist

# Default max size of the wheel is 250MB
ARG VLLM_MAX_SIZE_MB=250
ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB
ARG RUN_WHEEL_CHECK=true
RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
        python3 check-wheel-size.py dist; \
    else \
        echo "Skipping wheel size check."; \
    fi
#################### EXTENSION Build IMAGE ####################

#################### DEV IMAGE ####################

@@ -127,27 +129,11 @@ RUN --mount=type=cache,target=/root/.cache/pip \
    python3 -m pip install -r requirements-dev.txt

#################### DEV IMAGE ####################
#################### MAMBA Build IMAGE ####################
FROM dev as mamba-builder
# max jobs used for build
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}

WORKDIR /usr/src/mamba

COPY requirements-mamba.txt requirements-mamba.txt

# Download the wheel or build it if a pre-compiled release doesn't exist
RUN pip --verbose wheel -r requirements-mamba.txt \
    --no-build-isolation --no-deps --no-cache-dir

#################### MAMBA Build IMAGE ####################

#################### vLLM installation IMAGE ####################
# image with vLLM installed
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base
ARG CUDA_VERSION=12.4.1
ARG PYTHON_VERSION=3.10
ARG PYTHON_VERSION=3.12
WORKDIR /vllm-workspace
ENV DEBIAN_FRONTEND=noninteractive

@@ -159,6 +145,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
    && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
    && apt-get update -y \
    && apt-get install -y ccache software-properties-common git curl sudo vim python3-pip \
    && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
    && add-apt-repository ppa:deadsnakes/ppa \
    && apt-get update -y \
    && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \

@@ -179,13 +166,9 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
    --mount=type=cache,target=/root/.cache/pip \
    python3 -m pip install dist/*.whl --verbose

RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \
    --mount=type=cache,target=/root/.cache/pip \
    python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir

RUN --mount=type=cache,target=/root/.cache/pip \
    . /etc/environment && \
    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
#################### vLLM installation IMAGE ####################

@@ -197,6 +180,10 @@ FROM vllm-base AS test
ADD . /vllm-workspace/

# install development dependencies (for testing)
# A newer setuptools is required for installing some test dependencies from source that do not publish python 3.12 wheels
# This installation must complete before the test dependencies are collected and installed.
RUN --mount=type=cache,target=/root/.cache/pip \
    python3 -m pip install "setuptools>=74.1.1"
RUN --mount=type=cache,target=/root/.cache/pip \
    python3 -m pip install -r requirements-dev.txt
Dockerfile.cpu

@@ -2,9 +2,14 @@
FROM ubuntu:22.04 AS cpu-test-1

ENV CCACHE_DIR=/root/.cache/ccache

ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache

RUN --mount=type=cache,target=/var/cache/apt \
    apt-get update -y \
    && apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
    && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
    && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12

# https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html

@@ -25,6 +30,19 @@ RUN --mount=type=cache,target=/root/.cache/pip \
    pip install --upgrade pip && \
    pip install -r requirements-build.txt

# install oneDNN
RUN git clone -b rls-v3.5 https://github.com/oneapi-src/oneDNN.git

RUN --mount=type=cache,target=/root/.cache/ccache \
    cmake -B ./oneDNN/build -S ./oneDNN -G Ninja -DONEDNN_LIBRARY_TYPE=STATIC \
    -DONEDNN_BUILD_DOC=OFF \
    -DONEDNN_BUILD_EXAMPLES=OFF \
    -DONEDNN_BUILD_TESTS=OFF \
    -DONEDNN_BUILD_GRAPH=OFF \
    -DONEDNN_ENABLE_WORKLOAD=INFERENCE \
    -DONEDNN_ENABLE_PRIMITIVE=MATMUL && \
    cmake --build ./oneDNN/build --target install --config Release

FROM cpu-test-1 AS build

WORKDIR /workspace/vllm

@@ -40,7 +58,6 @@ COPY ./ ./
ARG VLLM_CPU_DISABLE_AVX512
ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512}

ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/pip \
    --mount=type=cache,target=/root/.cache/ccache \
    VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \
@@ -6,7 +6,9 @@ FROM $BASE_IMAGE
RUN echo "Base image is $BASE_IMAGE"

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y
RUN apt-get update \
    && apt-get install python3 python3-pip -y \
    && apt-get install -y ffmpeg libsm6 libxext6 libgl1

### Mount Point ###
# When launching the container, mount the code directory to /app
@@ -4,7 +4,8 @@
FROM ubuntu:22.04 AS dev

RUN apt-get update -y && \
    apt-get install -y python3-pip git
    apt-get install -y python3-pip git && \
    apt-get install -y ffmpeg libsm6 libxext6 libgl1
WORKDIR /workspace

# copy requirements
Dockerfile.ppc64le

@@ -2,21 +2,26 @@ FROM mambaorg/micromamba
ARG MAMBA_DOCKERFILE_ACTIVATE=1
USER root

RUN apt-get update -y && apt-get install -y git wget vim numactl gcc-12 g++-12 protobuf-compiler libprotobuf-dev && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/"

RUN apt-get update -y && apt-get install -y git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1

# Some packages in requirements-cpu are installed here
# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba
# Currently these may not be available for venv or pip directly
RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 pytorch-cpu=2.1.2 torchvision-cpu=0.16.2 && micromamba clean --all --yes
RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 torchvision-cpu=0.16.2 rust && micromamba clean --all --yes

COPY ./ /workspace/vllm

WORKDIR /workspace/vllm

# These packages will be in rocketce eventually
RUN pip install -v -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing
RUN pip install -v cmake xformers torch==2.3.1 uvloop==0.20.0 -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing

RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install

WORKDIR /vllm-workspace
ENTRYPOINT ["/opt/conda/bin/python3", "-m", "vllm.entrypoints.openai.api_server"]
WORKDIR /workspace/

RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks

ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
Dockerfile.tpu

@@ -1,9 +1,12 @@
ARG NIGHTLY_DATE="20240808"
ARG NIGHTLY_DATE="20240828"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"

FROM $BASE_IMAGE
WORKDIR /workspace

# Install some basic utilities
RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 libgl1

# Install the TPU and Pallas dependencies.
RUN python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
RUN python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Dockerfile.xpu

@@ -9,8 +9,7 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO
    chmod 644 /usr/share/keyrings/intel-graphics.gpg

RUN apt-get update -y \
    && apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip

    && apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1
COPY ./ /workspace/vllm

WORKDIR /workspace/vllm
MANIFEST.in

@@ -1,5 +1,4 @@
include LICENSE
include requirements-adag.txt
include requirements-common.txt
include requirements-cuda.txt
include requirements-rocm.txt
README.md (16 changed lines)

@@ -17,15 +17,16 @@ Easy, fast, and cheap LLM serving for everyone
---

**vLLM & NVIDIA Triton User Meetup (Monday, September 9, 5pm-9pm PT) at Fort Mason, San Francisco**
**vLLM, AMD, Anyscale Meet & Greet at [Ray Summit 2024](http://raysummit.anyscale.com) (Monday, Sept 30th, 5-7pm PT) at Marriott Marquis San Francisco**

We are excited to announce our sixth vLLM Meetup, in collaboration with NVIDIA Triton Team.
Join us to hear the vLLM's recent update about performance.
Register now [here](https://lu.ma/87q3nvnh) and be part of the event!
We are excited to announce our special vLLM event in collaboration with AMD and Anyscale.
Join us to learn more about recent advancements of vLLM on MI300X.
Register [here](https://lu.ma/db5ld9n5) and be a part of the event!

---

*Latest News* 🔥
- [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing).
- [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing).
- [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html).
- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing).

@@ -130,3 +131,10 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs
  year={2023}
}
```

## Contact Us

* For technical questions and feature requests, please use Github issues or discussions.
* For discussing with fellow users, please use Discord.
* For security disclosures, please use Github's security advisory feature.
* For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu.
benchmarks/backend_request_func.py

@@ -24,6 +24,7 @@ class RequestFuncInput:
    model: str
    best_of: int = 1
    use_beam_search: bool = False
    logprobs: Optional[int] = None


@dataclass

@@ -236,6 +237,7 @@ async def async_request_openai_completions(
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "logprobs": request_func_input.logprobs,
            "stream": True,
        }
        headers = {
benchmarks/benchmark_latency.py

@@ -10,7 +10,7 @@ import torch
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
from vllm.inputs import PromptInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser

@@ -205,13 +205,11 @@ if __name__ == '__main__':
        default=None,
        help=('path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard.'))
    parser.add_argument(
        "--device",
    parser.add_argument("--device",
                        type=str,
                        default="auto",
                        choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
                        help='device type for vLLM execution, supporting CUDA, OpenVINO and '
                        'CPU.')
                        choices=DEVICE_OPTIONS,
                        help='device type for vLLM execution')
    parser.add_argument('--block-size',
                        type=int,
                        default=16,
@ -56,20 +56,27 @@ class BenchmarkMetrics:
|
||||
total_input: int
|
||||
total_output: int
|
||||
request_throughput: float
|
||||
input_throughput: float
|
||||
output_throughput: float
|
||||
total_token_throughput: float
|
||||
mean_ttft_ms: float
|
||||
median_ttft_ms: float
|
||||
std_ttft_ms: float
|
||||
p99_ttft_ms: float
|
||||
percentiles_ttft_ms: List[Tuple[float, float]]
|
||||
mean_tpot_ms: float
|
||||
median_tpot_ms: float
|
||||
std_tpot_ms: float
|
||||
p99_tpot_ms: float
|
||||
percentiles_tpot_ms: List[Tuple[float, float]]
|
||||
mean_itl_ms: float
|
||||
median_itl_ms: float
|
||||
std_itl_ms: float
|
||||
p99_itl_ms: float
|
||||
percentiles_itl_ms: List[Tuple[float, float]]
|
||||
# E2EL stands for end-to-end latency per request.
|
||||
# It is the time taken on the client side from sending
|
||||
# a request to receiving a complete response.
|
||||
mean_e2el_ms: float
|
||||
median_e2el_ms: float
|
||||
std_e2el_ms: float
|
||||
percentiles_e2el_ms: List[Tuple[float, float]]
|
||||
|
||||
|
||||
def sample_sharegpt_requests(
|
||||
@ -188,8 +195,16 @@ def sample_sonnet_requests(
|
||||
|
||||
|
||||
def sample_random_requests(
|
||||
input_len: int, output_len: int, num_prompts: int, range_ratio: float,
|
||||
tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, int, int]]:
|
||||
prefix_len: int,
|
||||
input_len: int,
|
||||
output_len: int,
|
||||
num_prompts: int,
|
||||
range_ratio: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
prefix_token_ids = np.random.randint(0,
|
||||
tokenizer.vocab_size,
|
||||
size=prefix_len).tolist()
|
||||
|
||||
input_lens = np.random.randint(
|
||||
int(input_len * range_ratio),
|
||||
@ -204,10 +219,12 @@ def sample_random_requests(
|
||||
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
|
||||
input_requests = []
|
||||
for i in range(num_prompts):
|
||||
prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size
|
||||
prompt = tokenizer.decode(prefix_token_ids +
|
||||
[(offsets[i] + i + j) % tokenizer.vocab_size
|
||||
for j in range(input_lens[i])])
|
||||
|
||||
input_requests.append(
|
||||
(prompt, int(input_lens[i]), int(output_lens[i])))
|
||||
(prompt, int(prefix_len + input_lens[i]), int(output_lens[i])))
|
||||
|
||||
return input_requests
|
||||
|
||||
@ -235,6 +252,8 @@ def calculate_metrics(
|
||||
outputs: List[RequestFuncOutput],
|
||||
dur_s: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
selected_percentile_metrics: List[str],
|
||||
selected_percentiles: List[float],
|
||||
) -> Tuple[BenchmarkMetrics, List[int]]:
|
||||
actual_output_lens: List[int] = []
|
||||
total_input = 0
|
||||
@ -242,6 +261,7 @@ def calculate_metrics(
|
||||
itls: List[float] = []
|
||||
tpots: List[float] = []
|
||||
ttfts: List[float] = []
|
||||
e2els: List[float] = []
|
||||
for i in range(len(outputs)):
|
||||
if outputs[i].success:
|
||||
# We use the tokenizer to count the number of output tokens for all
|
||||
@ -258,6 +278,7 @@ def calculate_metrics(
|
||||
(outputs[i].latency - outputs[i].ttft) / (output_len - 1))
|
||||
itls += outputs[i].itl
|
||||
ttfts.append(outputs[i].ttft)
|
||||
e2els.append(outputs[i].latency)
|
||||
completed += 1
|
||||
else:
|
||||
actual_output_lens.append(0)
|
||||
@ -272,21 +293,29 @@ def calculate_metrics(
|
||||
total_input=total_input,
|
||||
total_output=sum(actual_output_lens),
|
||||
request_throughput=completed / dur_s,
|
||||
input_throughput=total_input / dur_s,
|
||||
output_throughput=sum(actual_output_lens) / dur_s,
|
||||
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
|
||||
mean_ttft_ms=np.mean(ttfts or 0) *
|
||||
1000, # ttfts is empty if streaming is not supported by backend
|
||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
||||
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
|
||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||
std_tpot_ms=np.std(tpots or 0) * 1000,
|
||||
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
|
||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
mean_itl_ms=np.mean(itls or 0) * 1000,
|
||||
median_itl_ms=np.median(itls or 0) * 1000,
|
||||
std_itl_ms=np.std(itls or 0) * 1000,
|
||||
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
|
||||
median_itl_ms=np.median(itls or 0) * 1000,
|
||||
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
mean_e2el_ms=np.median(e2els or 0) * 1000,
|
||||
std_e2el_ms=np.std(e2els or 0) * 1000,
|
||||
median_e2el_ms=np.mean(e2els or 0) * 1000,
|
||||
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
)
|
||||
|
||||
return metrics, actual_output_lens
|
||||
@ -299,11 +328,14 @@ async def benchmark(
|
||||
model_id: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
logprobs: Optional[int],
|
||||
best_of: int,
|
||||
use_beam_search: bool,
|
||||
request_rate: float,
|
||||
disable_tqdm: bool,
|
||||
profile: bool,
|
||||
selected_percentile_metrics: List[str],
|
||||
selected_percentiles: List[str],
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
@ -318,6 +350,7 @@ async def benchmark(
|
||||
api_url=api_url,
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
use_beam_search=use_beam_search,
|
||||
)
|
||||
@ -337,6 +370,7 @@ async def benchmark(
|
||||
api_url=base_url + "/start_profile",
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
use_beam_search=use_beam_search,
|
||||
)
|
||||
@ -358,6 +392,7 @@ async def benchmark(
|
||||
api_url=api_url,
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
use_beam_search=use_beam_search,
|
||||
)
|
||||
@ -375,6 +410,7 @@ async def benchmark(
|
||||
api_url=base_url + "/stop_profile",
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
use_beam_search=use_beam_search,
|
||||
)
|
||||
@ -392,6 +428,8 @@ async def benchmark(
|
||||
outputs=outputs,
|
||||
dur_s=benchmark_duration,
|
||||
tokenizer=tokenizer,
|
||||
selected_percentile_metrics=selected_percentile_metrics,
|
||||
selected_percentiles=selected_percentiles,
|
||||
)
|
||||
|
||||
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
|
||||
@ -403,27 +441,10 @@ async def benchmark(
|
||||
metrics.total_output))
|
||||
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
||||
metrics.request_throughput))
|
||||
print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):",
|
||||
metrics.input_throughput))
|
||||
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
||||
metrics.output_throughput))
|
||||
print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-'))
|
||||
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TTFT (ms):",
|
||||
metrics.median_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
||||
print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)',
|
||||
n=50,
|
||||
c='-'))
|
||||
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
|
||||
metrics.median_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
||||
print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-'))
|
||||
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
|
||||
print("=" * 50)
|
||||
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
|
||||
metrics.total_token_throughput))
|
||||
|
||||
result = {
|
||||
"duration": benchmark_duration,
|
||||
@ -431,20 +452,8 @@ async def benchmark(
|
||||
"total_input_tokens": metrics.total_input,
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"request_throughput": metrics.request_throughput,
|
||||
"input_throughput": metrics.input_throughput,
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"mean_ttft_ms": metrics.mean_ttft_ms,
|
||||
"median_ttft_ms": metrics.median_ttft_ms,
|
||||
"std_ttft_ms": metrics.std_ttft_ms,
|
||||
"p99_ttft_ms": metrics.p99_ttft_ms,
|
||||
"mean_tpot_ms": metrics.mean_tpot_ms,
|
||||
"median_tpot_ms": metrics.median_tpot_ms,
|
||||
"std_tpot_ms": metrics.std_tpot_ms,
|
||||
"p99_tpot_ms": metrics.p99_tpot_ms,
|
||||
"mean_itl_ms": metrics.mean_itl_ms,
|
||||
"median_itl_ms": metrics.median_itl_ms,
|
||||
"std_itl_ms": metrics.std_itl_ms,
|
||||
"p99_itl_ms": metrics.p99_itl_ms,
|
||||
"total_token_throughput": metrics.total_token_throughput,
|
||||
"input_lens": [output.prompt_len for output in outputs],
|
||||
"output_lens": actual_output_lens,
|
||||
"ttfts": [output.ttft for output in outputs],
|
||||
@ -452,6 +461,47 @@ async def benchmark(
|
||||
"generated_texts": [output.generated_text for output in outputs],
|
||||
"errors": [output.error for output in outputs],
|
||||
}
|
||||
|
||||
def process_one_metric(
|
||||
# E.g., "ttft"
|
||||
metric_attribute_name: str,
|
||||
# E.g., "TTFT"
|
||||
metric_name: str,
|
||||
# E.g., "Time to First Token"
|
||||
metric_header: str,
|
||||
):
|
||||
# This function prints and adds statistics of the specified
|
||||
# metric.
|
||||
if metric_attribute_name not in selected_percentile_metrics:
|
||||
return
|
||||
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
|
||||
print("{:<40} {:<10.2f}".format(
|
||||
f"Mean {metric_name} (ms):",
|
||||
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
|
||||
print("{:<40} {:<10.2f}".format(
|
||||
f"Median {metric_name} (ms):",
|
||||
getattr(metrics, f"median_{metric_attribute_name}_ms")))
|
||||
result[f"mean_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"mean_{metric_attribute_name}_ms")
|
||||
result[f"median_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"median_{metric_attribute_name}_ms")
|
||||
result[f"std_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"std_{metric_attribute_name}_ms")
|
||||
for p, value in getattr(metrics,
|
||||
f"percentiles_{metric_attribute_name}_ms"):
|
||||
p_word = str(int(p)) if int(p) == p else str(p)
|
||||
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
|
||||
value))
|
||||
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
|
||||
|
||||
process_one_metric("ttft", "TTFT", "Time to First Token")
|
||||
process_one_metric("tpot", "TPOT",
|
||||
"Time per Output Token (excl. 1st token)")
|
||||
process_one_metric("itl", "ITL", "Inter-token Latency")
|
||||
process_one_metric("e2el", "E2EL", "End-to-end Latency")
|
||||
|
||||
print("=" * 50)
|
||||
|
||||
return result
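
The percentile label logic above drops the trailing ".0" so that integral percentiles produce keys such as p99_ttft_ms, while fractional ones keep their decimal. A tiny illustrative helper (hypothetical, mirroring the p_word expression):

def percentile_key(metric: str, p: float) -> str:
    # Integral percentiles lose the ".0" suffix, fractional ones keep it.
    p_word = str(int(p)) if int(p) == p else str(p)
    return f"p{p_word}_{metric}_ms"

print(percentile_key("ttft", 99.0))  # p99_ttft_ms
print(percentile_key("itl", 99.9))   # p99.9_itl_ms
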
|
||||
|
||||
|
||||
@ -527,6 +577,7 @@ def main(args: argparse.Namespace):
|
||||
|
||||
elif args.dataset_name == "random":
|
||||
input_requests = sample_random_requests(
|
||||
prefix_len=args.random_prefix_len,
|
||||
input_len=args.random_input_len,
|
||||
output_len=args.random_output_len,
|
||||
num_prompts=args.num_prompts,
|
||||
@ -545,11 +596,16 @@ def main(args: argparse.Namespace):
|
||||
model_id=model_id,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
logprobs=args.logprobs,
|
||||
best_of=args.best_of,
|
||||
use_beam_search=args.use_beam_search,
|
||||
request_rate=args.request_rate,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
profile=args.profile,
|
||||
selected_percentile_metrics=args.percentile_metrics.split(","),
|
||||
selected_percentiles=[
|
||||
float(p) for p in args.metric_percentiles.split(",")
|
||||
],
|
||||
))
|
||||
|
||||
# Save config and results to json
|
||||
@ -682,6 +738,16 @@ if __name__ == "__main__":
|
||||
help=
|
||||
"Number of output tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logprobs",
|
||||
type=int,
|
||||
default=None,
|
||||
help=("Number of logprobs-per-token to compute & return as part of "
|
||||
"the request. If unspecified, then either (1) if beam search "
|
||||
"is disabled, no logprobs are computed & a single dummy "
|
||||
"logprob is returned for each token; or (2) if beam search "
|
||||
"is enabled 1 logprob per token is computed"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sonnet-prefix-len",
|
||||
type=int,
|
||||
@ -710,6 +776,14 @@ if __name__ == "__main__":
|
||||
help="Range of sampled ratio of input/output length, "
|
||||
"used only for random sampling.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random-prefix-len",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of fixed prefix tokens before random "
|
||||
" context. The length range of context in a random "
|
||||
" request is [random-prefix-len, "
|
||||
" random-prefix-len + random-prefix-len * random-range-ratio).")
|
||||
parser.add_argument(
|
||||
"--request-rate",
|
||||
type=float,
|
||||
@ -765,6 +839,23 @@ if __name__ == "__main__":
|
||||
"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
|
||||
" format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--percentile-metrics",
|
||||
type=str,
|
||||
default="ttft,tpot,itl",
|
||||
help="Comma-seperated list of selected metrics to report percentils. "
|
||||
"This argument specifies the metrics to report percentiles. "
|
||||
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
|
||||
"Default value is \"ttft,tpot,itl\".")
|
||||
parser.add_argument(
|
||||
"--metric-percentiles",
|
||||
type=str,
|
||||
default="99",
|
||||
help="Comma-seperated list of percentiles for selected metrics. "
|
||||
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
|
||||
"Default value is \"99\". "
|
||||
"Use \"--percentile-metrics\" to select metrics.",
|
||||
)
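
Together, the two flags select which latency distributions are summarized and at which quantiles; both values are plain comma-separated strings that the caller splits, with the percentiles parsed as floats. A self-contained sketch of that parsing (argparse only, using hypothetical example values):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--percentile-metrics", type=str, default="ttft,tpot,itl")
parser.add_argument("--metric-percentiles", type=str, default="99")
args = parser.parse_args(["--percentile-metrics", "ttft,e2el",
                          "--metric-percentiles", "25,50,99"])

selected_metrics = args.percentile_metrics.split(",")
selected_percentiles = [float(p) for p in args.metric_percentiles.split(",")]
print(selected_metrics, selected_percentiles)  # ['ttft', 'e2el'] [25.0, 50.0, 99.0]
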
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -6,13 +6,16 @@ import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from tqdm import tqdm
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
PreTrainedTokenizerBase)
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args)
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
||||
|
||||
|
||||
def sample_requests(
|
||||
@ -82,8 +85,11 @@ def run_vllm(
|
||||
max_num_batched_tokens: int,
|
||||
distributed_executor_backend: Optional[str],
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
num_scheduler_steps: int = 1,
|
||||
use_v2_block_manager: bool = False,
|
||||
download_dir: Optional[str] = None,
|
||||
load_format: str = EngineArgs.load_format,
|
||||
disable_async_output_proc: bool = False,
|
||||
) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(
|
||||
@ -106,6 +112,9 @@ def run_vllm(
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
load_format=load_format,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
use_v2_block_manager=use_v2_block_manager,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
)
|
||||
|
||||
# Add the requests to the engine.
|
||||
@ -129,6 +138,93 @@ def run_vllm(
|
||||
return end - start
|
||||
|
||||
|
||||
async def run_vllm_async(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
model: str,
|
||||
tokenizer: str,
|
||||
quantization: Optional[str],
|
||||
tensor_parallel_size: int,
|
||||
seed: int,
|
||||
n: int,
|
||||
use_beam_search: bool,
|
||||
trust_remote_code: bool,
|
||||
dtype: str,
|
||||
max_model_len: Optional[int],
|
||||
enforce_eager: bool,
|
||||
kv_cache_dtype: str,
|
||||
quantization_param_path: Optional[str],
|
||||
device: str,
|
||||
enable_prefix_caching: bool,
|
||||
enable_chunked_prefill: bool,
|
||||
max_num_batched_tokens: int,
|
||||
distributed_executor_backend: Optional[str],
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
num_scheduler_steps: int = 1,
|
||||
use_v2_block_manager: bool = False,
|
||||
download_dir: Optional[str] = None,
|
||||
load_format: str = EngineArgs.load_format,
|
||||
disable_async_output_proc: bool = False,
|
||||
disable_frontend_multiprocessing: bool = False,
|
||||
) -> float:
|
||||
from vllm import SamplingParams
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
quantization=quantization,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
seed=seed,
|
||||
trust_remote_code=trust_remote_code,
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
enforce_eager=enforce_eager,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
quantization_param_path=quantization_param_path,
|
||||
device=device,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
download_dir=download_dir,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
load_format=load_format,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
use_v2_block_manager=use_v2_block_manager,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
worker_use_ray=False,
|
||||
engine_use_ray=False,
|
||||
disable_log_requests=True,
|
||||
)
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args, disable_frontend_multiprocessing) as llm:
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: List[str] = []
|
||||
sampling_params: List[SamplingParams] = []
|
||||
for prompt, _, output_len in requests:
|
||||
prompts.append(prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=0.0 if use_beam_search else 1.0,
|
||||
top_p=1.0,
|
||||
use_beam_search=use_beam_search,
|
||||
ignore_eos=True,
|
||||
max_tokens=output_len,
|
||||
))
|
||||
|
||||
generators = []
|
||||
start = time.perf_counter()
|
||||
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
|
||||
generator = llm.generate(prompt, sp, request_id=f"test{i}")
|
||||
generators.append(generator)
|
||||
all_gens = merge_async_iterators(*generators)
|
||||
async for i, res in all_gens:
|
||||
pass
|
||||
end = time.perf_counter()
|
||||
return end - start
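
merge_async_iterators is a vLLM utility that interleaves the per-request output streams so the loop above can drain them as results arrive. A rough asyncio-only sketch of that merge-and-drain pattern (the merge and fake_request names below are stand-ins, not vLLM code; error handling omitted):

import asyncio
import time

async def merge(*gens):
    # Interleave several async generators as their items become ready.
    queue = asyncio.Queue()
    done = object()

    async def drain(gen):
        async for item in gen:
            await queue.put(item)
        await queue.put(done)

    tasks = [asyncio.create_task(drain(g)) for g in gens]
    remaining = len(tasks)
    while remaining:
        item = await queue.get()
        if item is done:
            remaining -= 1
        else:
            yield item

async def fake_request(i, delay):
    await asyncio.sleep(delay)        # stands in for llm.generate(...)
    yield f"request {i} done"

async def timed_drain():
    start = time.perf_counter()
    async for _ in merge(*(fake_request(i, 0.1 * i) for i in range(4))):
        pass                          # drain every stream to completion
    return time.perf_counter() - start

print(asyncio.run(timed_drain()))
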
|
||||
|
||||
|
||||
def run_hf(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
model: str,
|
||||
@ -224,7 +320,7 @@ def main(args: argparse.Namespace):
|
||||
args.output_len)
|
||||
|
||||
if args.backend == "vllm":
|
||||
elapsed_time = run_vllm(
|
||||
run_args = [
|
||||
requests, args.model, args.tokenizer, args.quantization,
|
||||
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
|
||||
args.trust_remote_code, args.dtype, args.max_model_len,
|
||||
@ -232,7 +328,16 @@ def main(args: argparse.Namespace):
|
||||
args.quantization_param_path, args.device,
|
||||
args.enable_prefix_caching, args.enable_chunked_prefill,
|
||||
args.max_num_batched_tokens, args.distributed_executor_backend,
|
||||
args.gpu_memory_utilization, args.download_dir, args.load_format)
|
||||
args.gpu_memory_utilization, args.num_scheduler_steps,
|
||||
args.use_v2_block_manager, args.download_dir, args.load_format,
|
||||
args.disable_async_output_proc
|
||||
]
|
||||
|
||||
if args.async_engine:
|
||||
run_args.append(args.disable_frontend_multiprocessing)
|
||||
elapsed_time = uvloop.run(run_vllm_async(*run_args))
|
||||
else:
|
||||
elapsed_time = run_vllm(*run_args)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
@ -346,17 +451,23 @@ if __name__ == "__main__":
|
||||
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
|
||||
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
|
||||
'instead supported for common inference criteria.')
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
parser.add_argument("--device",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
|
||||
help='device type for vLLM execution, supporting CUDA, OpenVINO and '
|
||||
'CPU.')
|
||||
choices=DEVICE_OPTIONS,
|
||||
help='device type for vLLM execution')
|
||||
parser.add_argument(
|
||||
"--num-scheduler-steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Maximum number of forward steps per scheduler call.")
|
||||
parser.add_argument("--use-v2-block-manager",
|
||||
action='store_true',
|
||||
help="Enable block manager v2.")
|
||||
parser.add_argument(
|
||||
"--enable-prefix-caching",
|
||||
action='store_true',
|
||||
help="enable automatic prefix caching for vLLM backend.")
|
||||
help="Enable automatic prefix caching for vLLM backend.")
|
||||
parser.add_argument("--enable-chunked-prefill",
|
||||
action='store_true',
|
||||
help="enable chunked prefill for vLLM backend.")
|
||||
@ -405,6 +516,19 @@ if __name__ == "__main__":
|
||||
'section for more information.\n'
|
||||
'* "bitsandbytes" will load the weights using bitsandbytes '
|
||||
'quantization.\n')
|
||||
parser.add_argument(
|
||||
"--disable-async-output-proc",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable async output processor for vLLM backend.")
|
||||
parser.add_argument("--async-engine",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Use vLLM async engine rather than LLM class.")
|
||||
parser.add_argument("--disable-frontend-multiprocessing",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable decoupled async engine frontend.")
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
|
@ -6,7 +6,7 @@ TOKENS=$2
|
||||
|
||||
docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \
|
||||
-v $PWD/data:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:1.4.0 \
|
||||
ghcr.io/huggingface/text-generation-inference:2.2.0 \
|
||||
--model-id $MODEL \
|
||||
--sharded false \
|
||||
--max-input-length 1024 \
|
||||
|
@ -1,4 +1,5 @@
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
|
||||
#
|
||||
# Define environment variables for special configurations
|
||||
@ -83,12 +84,7 @@ endif()
|
||||
|
||||
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
|
||||
|
||||
list(APPEND LIBS "numa")
|
||||
|
||||
|
||||
#
|
||||
# Define extension targets
|
||||
#
|
||||
list(APPEND LIBS dnnl numa)
|
||||
|
||||
#
|
||||
# _C extension
|
||||
@ -102,6 +98,16 @@ set(VLLM_EXT_SRC
|
||||
"csrc/cpu/pos_encoding.cpp"
|
||||
"csrc/cpu/torch_bindings.cpp")
|
||||
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/quant.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
|
||||
#
|
||||
# Define extension targets
|
||||
#
|
||||
|
||||
define_gpu_extension_target(
|
||||
_C
|
||||
DESTINATION vllm
|
||||
|
@ -350,6 +350,7 @@ function (define_gpu_extension_target GPU_MOD_NAME)
|
||||
target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
|
||||
${GPU_INCLUDE_DIRECTORIES})
|
||||
|
||||
# TODO: is torch_python_LIBRARY needed?
|
||||
target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${torch_python_LIBRARY}
|
||||
${GPU_LIBRARIES})
|
||||
|
||||
|
@ -387,7 +387,8 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
|
||||
// This needs to be implemented and throw a TypeError in order for
|
||||
// PyTorch's opcheck to work on ops that use ScalarTypes.
|
||||
int64_t len() const {
|
||||
throw c10::TypeError("__len__ not implemented");
|
||||
throw c10::TypeError({__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
|
||||
"__len__ not implemented");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -24,8 +24,8 @@ namespace vec_op {
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#else
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
std::cout << #NAME << " invoked." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
|
||||
RECORD_FUNCTION(#NAME, c10::ArrayRef<c10::IValue>({}));
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
@ -106,6 +106,12 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
explicit BF16Vec16(const FP32Vec16 &);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
|
||||
_mm256_mask_storeu_epi16(ptr, mask, reg);
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef __AVX512F__
|
||||
@ -313,8 +319,28 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return FP32Vec16(_mm512_div_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
|
||||
return FP32Vec16(_mm512_min_ps(max.reg, _mm512_max_ps(min.reg, reg)));
|
||||
}
|
||||
|
||||
FP32Vec16 max(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm512_max_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
|
||||
return FP32Vec16(_mm512_mask_max_ps(reg, mask, reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 abs() const {
|
||||
return FP32Vec16(_mm512_abs_ps(reg));
|
||||
}
|
||||
|
||||
float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
|
||||
|
||||
float reduce_max() const { return _mm512_reduce_max_ps(reg); }
|
||||
|
||||
template <int group_size> float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
|
||||
@ -323,6 +349,12 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
}
|
||||
|
||||
void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
|
||||
|
||||
void save(float* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
|
||||
_mm512_mask_storeu_ps(ptr, mask, reg);
|
||||
}
|
||||
};
|
||||
#else
|
||||
struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
@ -433,6 +465,32 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
};
|
||||
#endif
|
||||
|
||||
#ifdef __AVX512F__
|
||||
struct INT8Vec16: public Vec<INT8Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
__m128i reg;
|
||||
int8_t values[VEC_ELEM_NUM];
|
||||
};
|
||||
|
||||
__m128i reg;
|
||||
|
||||
explicit INT8Vec16(const FP32Vec16& vec) : reg(
|
||||
_mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
|
||||
) {}
|
||||
|
||||
void save(int8_t* ptr) const {
|
||||
_mm_storeu_epi8(ptr, reg);
|
||||
}
|
||||
|
||||
void save(int8_t* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
|
||||
_mm_mask_storeu_epi8(ptr, mask, reg);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T> struct VecType { using vec_type = void; };
|
||||
|
||||
template <typename T> using vec_t = typename VecType<T>::vec_type;
|
||||
|
168
csrc/cpu/dnnl_helper.hpp
Normal file
@ -0,0 +1,168 @@
|
||||
#ifndef DNNL_HELPER_HPP
|
||||
#define DNNL_HELPER_HPP
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
|
||||
#include "oneapi/dnnl/dnnl.hpp"
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
struct DNNLType {
|
||||
static constexpr dnnl::memory::data_type type =
|
||||
dnnl::memory::data_type::undef;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<int8_t> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<int32_t> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<float> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<c10::BFloat16> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
constexpr inline dnnl::memory::data_type get_dnnl_type() {
|
||||
return DNNLType<std::decay_t<T>>::type;
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
template <bool InputNoScale>
|
||||
class DNNLPrimitiveHelper {
|
||||
public:
|
||||
// I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias)
|
||||
// A: [M, K], row-major
|
||||
// B: [K, N], column-major
|
||||
// C: [M, N], row-major
|
||||
// bias: [N], row-major, optional
|
||||
// a_scales: [MS]
|
||||
// b_scales: [NS]
|
||||
// Note: Due to the limitation of oneDNN
|
||||
// (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
|
||||
// not supported.
|
||||
template <typename OutputT, typename BiasT>
|
||||
static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
|
||||
const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
|
||||
dnnl_dim_t K, const float* a_scales,
|
||||
const float* b_scales, dnnl_dim_t MS,
|
||||
dnnl_dim_t NS) {
|
||||
auto&& OutputType = get_dnnl_type<OutputT>();
|
||||
auto&& BiasType = get_dnnl_type<BiasT>();
|
||||
|
||||
dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});
|
||||
dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K});
|
||||
dnnl::memory::desc c_md({M, N}, OutputType, {N, 1});
|
||||
|
||||
dnnl::primitive_attr attr;
|
||||
if constexpr (!InputNoScale) {
|
||||
if (MS == 1) {
|
||||
// per-tensor
|
||||
attr.set_scales_mask(DNNL_ARG_SRC, 0);
|
||||
} else {
|
||||
// per-token
|
||||
TORCH_CHECK(false, "per-token quantization is unsupported.");
|
||||
}
|
||||
}
|
||||
|
||||
if (NS == 1) {
|
||||
// per-tensor
|
||||
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
|
||||
} else {
|
||||
// per-channel
|
||||
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
|
||||
}
|
||||
|
||||
dnnl::matmul::primitive_desc matmul_pd;
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
bias_md, c_md, attr);
|
||||
} else {
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
c_md, attr);
|
||||
}
|
||||
dnnl::matmul matmul(matmul_pd);
|
||||
|
||||
auto& engine = default_engine();
|
||||
|
||||
dnnl::memory a_m(a_md, engine, (void*)a);
|
||||
dnnl::memory b_m(b_md, engine, (void*)b);
|
||||
dnnl::memory c_m(c_md, engine, (void*)c);
|
||||
dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine,
|
||||
(void*)a_scales);
|
||||
dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine,
|
||||
(void*)b_scales);
|
||||
|
||||
auto& stream = default_stream();
|
||||
if constexpr (InputNoScale) {
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({N}, BiasType, {1});
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
}
|
||||
} else {
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({N}, BiasType, {1});
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
}
|
||||
}
|
||||
stream.wait();
|
||||
}
|
||||
|
||||
private:
|
||||
static dnnl::engine& default_engine() {
|
||||
static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
|
||||
return engine;
|
||||
}
|
||||
|
||||
static dnnl::stream& default_stream() {
|
||||
static dnnl::stream stream(default_engine());
|
||||
return stream;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
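
The helper above sets up an oneDNN s8s8 matmul with a per-tensor activation scale and per-tensor or per-output-channel weight scales, keeping the bias unquantized. A rough numpy reference of that scaling arithmetic (illustrative only; it mirrors the comment C = a_scales * A @ (b_scales * B^T) + bias, not the oneDNN execution path):

import numpy as np

def gemm_s8s8_reference(a_i8, b_i8, a_scale, b_scales, bias=None):
    # a_i8: [M, K] int8, b_i8: [K, N] int8, a_scale: scalar,
    # b_scales: scalar or [N], bias: optional [N] float (unquantized).
    acc = a_i8.astype(np.int32) @ b_i8.astype(np.int32)   # int32 accumulation
    c = acc.astype(np.float32) * a_scale * np.asarray(b_scales, dtype=np.float32)
    if bias is not None:
        c = c + bias
    return c

rng = np.random.default_rng(0)
a = rng.integers(-128, 128, size=(4, 8), dtype=np.int8)
b = rng.integers(-128, 128, size=(8, 3), dtype=np.int8)
print(gemm_s8s8_reference(a, b, a_scale=0.02, b_scales=np.full(3, 0.01)))
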
|
294
csrc/cpu/quant.cpp
Normal file
@ -0,0 +1,294 @@
|
||||
#include "cpu_types.hpp"
|
||||
#include "dnnl_helper.hpp"
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
struct KernelVecType {
|
||||
using load_vec_type = void;
|
||||
using cvt_vec_type = void;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<float> {
|
||||
using load_vec_type = vec_op::FP32Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using load_vec_type = vec_op::BF16Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#ifdef __AVX512F__
|
||||
template <typename scalar_t>
|
||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
const cvt_vec_t inv_scale(1.0 / *scale);
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
|
||||
if (j + vec_elem_num == hidden_size) {
|
||||
elems_int8.save(output + i * hidden_size + j);
|
||||
} else {
|
||||
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
float* scale, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
cvt_vec_t max_abs(0.0);
|
||||
{
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
max_abs = max_abs.max(elems_fp32.abs());
|
||||
}
|
||||
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
|
||||
if (j + vec_elem_num == hidden_size) {
|
||||
max_abs = max_abs.max(elems_fp32.abs());
|
||||
} else {
|
||||
max_abs = max_abs.max(elems_fp32.abs(), hidden_size - j);
|
||||
}
|
||||
}
|
||||
|
||||
float scale_val = max_abs.reduce_max() / 127.0f;
|
||||
scale[i] = scale_val;
|
||||
const cvt_vec_t inv_scale(1.0 / scale_val);
|
||||
|
||||
{
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
|
||||
if (j + vec_elem_num == hidden_size) {
|
||||
elems_int8.save(output + i * hidden_size + j);
|
||||
} else {
|
||||
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool Bias, typename scalar_t>
|
||||
void dynamic_output_scale_impl(const float* input, scalar_t* output,
|
||||
const float* scale, const scalar_t* bias,
|
||||
const int num_tokens, const int hidden_size) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
int j = 0;
|
||||
cvt_vec_t token_scale_vec(scale[i]);
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + j);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + j);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
|
||||
if (j + vec_elem_num == hidden_size) {
|
||||
elems_out.save(output + i * hidden_size + j);
|
||||
} else {
|
||||
elems_out.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
template <typename scalar_t>
|
||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.")
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
float* scale, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.")
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void dynamic_output_scale_impl() {
|
||||
TORCH_CHECK(false, "dynamic_output_scale_impl requires AVX512 support.")
|
||||
}
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
|
||||
const torch::Tensor& a, // [M, IC], row-major
|
||||
const torch::Tensor& b, // [IC, OC], column-major
|
||||
const torch::Tensor& a_scales, // [1] or [M]
|
||||
const torch::Tensor& b_scales, // [1] or [OC]
|
||||
const c10::optional<torch::Tensor>& bias // [OC]
|
||||
) {
|
||||
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
|
||||
"int8_scaled_mm only supports INT8 inputs.")
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
b.size(1) == c.size(1));
|
||||
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
|
||||
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
|
||||
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
|
||||
bias->dim() == 1);
|
||||
}
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "cutlass_scaled_mm", [&] {
|
||||
if (a_scales.numel() != 1) {
|
||||
// per-token
|
||||
// Note: oneDNN doesn't support per-token activation quantization
|
||||
torch::Tensor tmp_fp32_out =
|
||||
torch::empty_like(c, ::at::ScalarType::Float);
|
||||
DNNLPrimitiveHelper<true>::gemm_s8s8_jit(
|
||||
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
|
||||
tmp_fp32_out.data_ptr<float>(), (void*)(0), a.size(0), b.size(1),
|
||||
a.size(1), (float*)(0), b_scales.data_ptr<float>(), 0,
|
||||
b_scales.numel());
|
||||
if (bias.has_value()) {
|
||||
dynamic_output_scale_impl<true>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), bias->data_ptr<scalar_t>(), c.size(0),
|
||||
c.size(1));
|
||||
} else {
|
||||
dynamic_output_scale_impl<false>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), (scalar_t*)(0), c.size(0), c.size(1));
|
||||
}
|
||||
} else {
|
||||
// per-tensor
|
||||
if (bias.has_value()) {
|
||||
DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
|
||||
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
|
||||
bias->data_ptr<scalar_t>(), a.size(0), b.size(1), a.size(1),
|
||||
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
a_scales.numel(), b_scales.numel());
|
||||
} else {
|
||||
DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
|
||||
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
|
||||
(void*)(0), a.size(0), b.size(1), a.size(1),
|
||||
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
a_scales.numel(), b_scales.numel());
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// static-per-tensor quantization.
|
||||
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
const torch::Tensor& input, // [..., hidden_size]
|
||||
const torch::Tensor& scale) {
|
||||
CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(scale.numel() == 1);
|
||||
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = input.numel() / hidden_size;
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
|
||||
static_scaled_int8_quant_impl(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), num_tokens, hidden_size);
|
||||
});
|
||||
}
|
||||
|
||||
// dynamic-per-token quantization.
|
||||
void dynamic_scaled_int8_quant(
|
||||
torch::Tensor& out, // [..., hidden_size]
|
||||
const torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& scale // [..., 1]
|
||||
) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
|
||||
int const hidden_size = input.size(-1);
|
||||
int const num_tokens = input.numel() / hidden_size;
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
|
||||
dynamic_scaled_int8_quant_impl(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), num_tokens, hidden_size);
|
||||
});
|
||||
}
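
The two entry points differ only in where the scale comes from: the static path clamps against a caller-provided per-tensor scale, while the dynamic path derives one scale per token from the row-wise absolute maximum. A numpy sketch of that math (illustrative; the kernels above vectorize the same computation with AVX512 and handle the ragged tail with masked stores):

import numpy as np

def static_scaled_int8_quant_ref(x, scale):
    # Per-tensor quantization with a precomputed scale.
    return np.clip(np.rint(x / scale), -128, 127).astype(np.int8)

def dynamic_scaled_int8_quant_ref(x):
    # Per-token (row-wise) quantization; returns (int8 values, per-row scales).
    # (All-zero rows are not special-cased in this sketch.)
    scales = np.abs(x).max(axis=-1, keepdims=True) / 127.0
    return np.rint(x / scales).astype(np.int8), scales

x = np.random.randn(2, 16).astype(np.float32)
print(static_scaled_int8_quant_ref(x, scale=0.05))
print(dynamic_scaled_int8_quant_ref(x)[1])
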
|
@ -4,7 +4,12 @@
|
||||
|
||||
#include <torch/library.h>
|
||||
|
||||
void init_cpu_threads_env(const std::string& cpu_ids);
|
||||
std::string init_cpu_threads_env(const std::string& cpu_ids);
|
||||
|
||||
void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
|
||||
const torch::Tensor& b, const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales,
|
||||
const c10::optional<torch::Tensor>& bias);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
@ -27,8 +32,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// PagedAttention V2.
|
||||
ops.def(
|
||||
"paged_attention_v2("
|
||||
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
|
||||
" Tensor tmp_out, Tensor query, Tensor key_cache,"
|
||||
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
|
||||
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
@ -84,6 +89,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor! key, int head_size,"
|
||||
" Tensor cos_sin_cache, bool is_neox) -> ()");
|
||||
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
|
||||
|
||||
// Quantization
|
||||
#ifdef __AVX512F__
|
||||
// Compute int8 quantized tensor for given scaling factor.
|
||||
ops.def(
|
||||
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
|
||||
"()");
|
||||
ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
|
||||
// Compute int8 quantized tensor and scaling factor
|
||||
ops.def(
|
||||
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
|
||||
"()");
|
||||
ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
|
||||
&dynamic_scaled_int8_quant);
|
||||
// W8A8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
// quantization.
|
||||
ops.def(
|
||||
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor? bias) -> ()");
|
||||
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
@ -95,8 +122,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
|
||||
// Copy the cache blocks from src to dst.
|
||||
cache_ops.def(
|
||||
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
|
||||
"block_mapping) -> ()");
|
||||
"copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
|
||||
"Tensor block_mapping) -> ()");
|
||||
cache_ops.impl("copy_blocks", torch::kCPU, ©_blocks);
|
||||
|
||||
// Reshape the key and value tensors and cache them.
|
||||
@ -111,7 +138,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
||||
// CPU utils
|
||||
utils.def("init_cpu_threads_env(str cpu_ids) -> ()", &init_cpu_threads_env);
|
||||
utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
@ -5,7 +5,7 @@
|
||||
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
void init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
std::string init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
|
||||
TORCH_CHECK(omp_cpu_mask->size > 0);
|
||||
std::vector<int> omp_cpu_ids;
|
||||
@ -51,15 +51,40 @@ void init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
torch::set_num_threads((int)omp_cpu_ids.size());
|
||||
TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads());
|
||||
TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());
|
||||
|
||||
std::vector<std::pair<int, int>> thread_core_mapping;
|
||||
thread_core_mapping.reserve(omp_cpu_ids.size());
|
||||
omp_lock_t writelock;
|
||||
omp_init_lock(&writelock);
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
|
||||
cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size);
|
||||
size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size);
|
||||
CPU_ZERO_S(size, mask);
|
||||
CPU_SET_S(omp_cpu_ids[i], size, mask);
|
||||
sched_setaffinity(0, sizeof(cpu_set_t), mask);
|
||||
CPU_FREE(mask);
|
||||
cpu_set_t mask;
|
||||
CPU_ZERO(&mask);
|
||||
CPU_SET(omp_cpu_ids[i], &mask);
|
||||
int ret = sched_setaffinity(0, sizeof(cpu_set_t), &mask);
|
||||
if (ret == -1) {
|
||||
TORCH_CHECK(false,
|
||||
"sched_setaffinity failed. errno: " + std::to_string(errno));
|
||||
}
|
||||
|
||||
numa_free_nodemask(omp_cpu_mask);
|
||||
omp_set_lock(&writelock);
|
||||
thread_core_mapping.emplace_back(gettid(), omp_cpu_ids[i]);
|
||||
omp_unset_lock(&writelock);
|
||||
}
|
||||
|
||||
omp_destroy_lock(&writelock);
|
||||
|
||||
numa_free_nodemask(omp_cpu_mask);
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "OMP threads binding of Process " << getpid() << ":\n";
|
||||
std::sort(thread_core_mapping.begin(), thread_core_mapping.end(),
|
||||
[](auto&& a, auto&& b) { return a.second < b.second; });
|
||||
for (auto&& item : thread_core_mapping) {
|
||||
ss << "\t"
|
||||
<< "OMP tid: " << item.first << ", core " << item.second << "\n";
|
||||
}
|
||||
|
||||
return ss.str();
|
||||
}
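
The rewritten binding loop pins each OpenMP worker to one core with sched_setaffinity and reports the resulting tid-to-core mapping back to the caller. For reference, the same binding primitive is exposed to Python on Linux (illustrative only; vLLM drives it from the C++ side above):

import os

# Linux-only: bind the calling thread to core 3, then read the binding back.
os.sched_setaffinity(0, {3})
print(os.sched_getaffinity(0))  # {3}
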
|
||||
|
700
csrc/mamba/causal_conv1d/causal_conv1d.cu
Normal file
@ -0,0 +1,700 @@
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
|
||||
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "causal_conv1d.h"
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
|
||||
#include "static_switch.h"
|
||||
|
||||
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
||||
if (ITYPE == at::ScalarType::Half) { \
|
||||
using input_t = at::Half; \
|
||||
using weight_t = at::Half; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
||||
using input_t = at::BFloat16; \
|
||||
using weight_t = at::BFloat16; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::Float) { \
|
||||
using input_t = float; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
||||
}
|
||||
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template <typename input_t, typename weight_t>
|
||||
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
void set_conv_params_fwd(ConvParamsBase ¶ms,
|
||||
// sizes
|
||||
const size_t batch,
|
||||
const size_t dim,
|
||||
const size_t seqlen,
|
||||
const size_t width,
|
||||
// device pointers
|
||||
const at::Tensor x,
|
||||
const at::Tensor weight,
|
||||
const at::Tensor out,
|
||||
void* bias_ptr,
|
||||
bool silu_activation) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
|
||||
params.batch = batch;
|
||||
params.dim = dim;
|
||||
params.seqlen = seqlen;
|
||||
params.width = width;
|
||||
|
||||
params.silu_activation = silu_activation;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.x_ptr = x.data_ptr();
|
||||
params.weight_ptr = weight.data_ptr();
|
||||
params.bias_ptr = bias_ptr;
|
||||
params.out_ptr = out.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.x_batch_stride = x.stride(0);
|
||||
params.x_c_stride = x.stride(1);
|
||||
params.x_l_stride = x.stride(-1);
|
||||
params.weight_c_stride = weight.stride(0);
|
||||
params.weight_width_stride = weight.stride(1);
|
||||
params.out_batch_stride = out.stride(0);
|
||||
params.out_c_stride = out.stride(1);
|
||||
params.out_l_stride = out.stride(-1);
|
||||
}
|
||||
|
||||
|
||||
at::Tensor
|
||||
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
const c10::optional<at::Tensor> &bias_,
|
||||
const c10::optional<at::Tensor> &seq_idx_,
|
||||
const c10::optional<at::Tensor> &initial_states_,
|
||||
const c10::optional<at::Tensor> &final_states_out_,
|
||||
bool silu_activation) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int seqlen = sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
|
||||
const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
|
||||
|
||||
if (is_channel_last) {
|
||||
TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
|
||||
TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
|
||||
}
|
||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
if (seq_idx_.has_value()) {
|
||||
TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout");
|
||||
auto seq_idx = seq_idx_.value();
|
||||
TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
|
||||
TORCH_CHECK(seq_idx.is_cuda());
|
||||
TORCH_CHECK(seq_idx.is_contiguous());
|
||||
CHECK_SHAPE(seq_idx, batch_size, seqlen);
|
||||
}
|
||||
|
||||
at::Tensor out = torch::empty_like(x);
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
||||
silu_activation);
|
||||
|
||||
if (seq_idx_.has_value()) {
|
||||
params.seq_idx_ptr = seq_idx_.value().data_ptr();
|
||||
} else {
|
||||
params.seq_idx_ptr = nullptr;
|
||||
}
|
||||
|
||||
if (initial_states_.has_value()) {
|
||||
TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
|
||||
auto initial_states = initial_states_.value();
|
||||
TORCH_CHECK(initial_states.scalar_type() == input_type);
|
||||
TORCH_CHECK(initial_states.is_cuda());
|
||||
CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
|
||||
TORCH_CHECK(initial_states.stride(1) == 1);
|
||||
params.initial_states_ptr = initial_states.data_ptr();
|
||||
params.initial_states_batch_stride = initial_states.stride(0);
|
||||
params.initial_states_c_stride = initial_states.stride(1);
|
||||
params.initial_states_l_stride = initial_states.stride(2);
|
||||
} else {
|
||||
params.initial_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
if (final_states_out_.has_value()) {
|
||||
TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout");
|
||||
auto final_states = final_states_out_.value();
|
||||
TORCH_CHECK(final_states.scalar_type() == input_type);
|
||||
TORCH_CHECK(final_states.is_cuda());
|
||||
CHECK_SHAPE(final_states, batch_size, dim, width - 1);
|
||||
TORCH_CHECK(final_states.stride(1) == 1);
|
||||
params.final_states_ptr = final_states.data_ptr();
|
||||
params.final_states_batch_stride = final_states.stride(0);
|
||||
params.final_states_c_stride = final_states.stride(1);
|
||||
params.final_states_l_stride = final_states.stride(2);
|
||||
} else {
|
||||
params.final_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
|
||||
if (!is_channel_last) {
|
||||
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
} else {
|
||||
causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
}
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
at::Tensor
|
||||
causal_conv1d_update(const at::Tensor &x,
|
||||
const at::Tensor &conv_state,
|
||||
const at::Tensor &weight,
|
||||
const c10::optional<at::Tensor> &bias_,
|
||||
bool silu_activation) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");
|
||||
TORCH_CHECK(conv_state.scalar_type() == input_type);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(conv_state.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int width = weight.size(-1);
|
||||
|
||||
CHECK_SHAPE(x, batch_size, dim);
|
||||
CHECK_SHAPE(conv_state, batch_size, dim, width);
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
at::Tensor out = torch::empty_like(x);
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out,
|
||||
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
||||
silu_activation);
|
||||
params.conv_state_ptr = conv_state.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.conv_state_batch_stride = conv_state.stride(0);
|
||||
params.conv_state_c_stride = conv_state.stride(1);
|
||||
params.conv_state_l_stride = conv_state.stride(2);
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
|
||||
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_fwd_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
||||
static_assert(kWidth <= kNElts);
|
||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
||||
static constexpr int kSmemIOSize = kIsVecLoad
|
||||
? 0
|
||||
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
||||
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
|
||||
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNElts = Ktraits::kNElts;
|
||||
static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
||||
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
||||
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
||||
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y;
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
|
||||
if (tidx == 0) {
|
||||
input_t zeros[kNElts] = {0};
|
||||
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
|
||||
}
|
||||
|
||||
float weight_vals[kWidth];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNElts;
|
||||
const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
|
||||
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
||||
input_t x_vals_load[2 * kNElts] = {0};
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
__syncthreads();
|
||||
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
x += kChunkSize;
|
||||
__syncthreads();
|
||||
// Thread kNThreads - 1 doesn't write yet, so that thread 0 can read
// the last elements of the previous chunk.
|
||||
if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
__syncthreads();
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
||||
__syncthreads();
|
||||
// Now thread kNThreads - 1 can write the last elements of the current chunk.
|
||||
if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
|
||||
float x_vals[2 * kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
||||
|
||||
float out_vals[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = bias_val;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
||||
}
|
||||
}
|
||||
|
||||
if (params.silu_activation) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
input_t out_vals_store[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
out += kChunkSize;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
||||
BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
|
||||
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize;
|
||||
dim3 grid(params.batch, params.dim);
|
||||
|
||||
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
|
||||
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
#ifndef USE_ROCM
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
#else
|
||||
// There is a slight signature discrepancy between the HIP and CUDA "FuncSetAttribute" functions.
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
||||
#endif
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
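For orientation, here is a minimal host-side reference (illustration only; causal_conv1d_ref is a hypothetical name, not part of this codebase) of the per-channel computation the kernels above implement: a causal depthwise convolution over past and current inputs, followed by an optional SiLU.

// Host reference of a causal depthwise conv1d with optional SiLU for a single
// (batch, channel) slice; mirrors out[t] = bias + sum_k w[k] * x[t-(width-1-k)].
#include <cmath>
#include <vector>

std::vector<float> causal_conv1d_ref(const std::vector<float>& x,
                                     const std::vector<float>& w,  // size = width
                                     float bias, bool silu) {
    const int width = static_cast<int>(w.size());
    std::vector<float> out(x.size());
    for (int t = 0; t < static_cast<int>(x.size()); ++t) {
        float acc = bias;
        for (int k = 0; k < width; ++k) {
            const int idx = t - (width - 1 - k);   // only past/current inputs
            if (idx >= 0) acc += w[k] * x[idx];
        }
        out[t] = silu ? acc / (1.f + std::exp(-acc)) : acc;
    }
    return out;
}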
|
||||
|
||||
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_channellast_fwd_kernel_traits {
|
||||
// The cache line is 128 bytes, and we try to read 16 bytes per thread.
|
||||
// So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
|
||||
// That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
|
||||
// threads). Each load is 16 x 32|64 elements in the L x C dimensions.
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static_assert(kNThreads % 32 == 0);
|
||||
static constexpr int kNWarps = kNThreads / 32;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kChunkSizeL = kChunkSizeL_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
||||
static constexpr int kNEltsPerRow = 128 / kNBytes;
|
||||
static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
|
||||
static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
|
||||
static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
|
||||
static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
|
||||
static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
|
||||
static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
|
||||
static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
|
||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
// using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
// using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
// static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
||||
// sizeof(typename BlockStoreT::TempStorage)});
|
||||
// static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
|
||||
};
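A compile-time restatement of the tile geometry described in the comment above, for the fp16 case with 128 threads per block (the constants below are illustrative and mirror the traits rather than replacing them):

// Sketch of the channel-last tile geometry for kNBytes == 2 and 128 threads:
// 8 threads cover one 128-byte row of 64 fp16 channels, each warp covers 4
// sequence positions, and a 128-thread block loads 16 positions at a time.
namespace channellast_geometry_sketch {
constexpr int kNBytes         = 2;
constexpr int kNElts          = 8;                          // 16-byte loads
constexpr int kNEltsPerRow    = 128 / kNBytes;              // 64
constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts;      // 8
constexpr int kNColsPerWarp   = 32 / kNThreadsPerRow;       // 4
constexpr int kNColsPerLoad   = kNColsPerWarp * (128 / 32); // 16
static_assert(kNThreadsPerRow == 8 && kNColsPerWarp == 4 && kNColsPerLoad == 16);
}  // namespace channellast_geometry_sketch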
|
||||
|
||||
template<typename Ktraits, bool kHasSeqIdx>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNElts = Ktraits::kNElts;
|
||||
constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
|
||||
constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
|
||||
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
||||
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
// Shared memory.
|
||||
__shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
|
||||
|
||||
const int batch_id = blockIdx.x;
|
||||
const int chunk_l_id = blockIdx.y;
|
||||
const int chunk_c_id = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
const int l_idx = tid / kNThreadsPerC;
|
||||
const int c_idx = tid % kNThreadsPerC;
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
|
||||
+ chunk_c_id * kChunkSizeC * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
|
||||
+ batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
|
||||
input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
|
||||
: reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
// The last L-chunk will also have enough info to write to final states, since it also contains a few x values
|
||||
// from the previous L-chunk.
|
||||
input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
|
||||
: reinterpret_cast<input_t *>(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
||||
input_t x_vals_load[kNElts] = {0};
|
||||
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
|
||||
}
|
||||
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
||||
}
|
||||
// Load the elements from the previous chunk that are needed for convolution.
|
||||
if (l_idx < kWidth - 1) {
|
||||
input_t x_vals_load[kNElts] = {0};
|
||||
if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
|
||||
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
|
||||
} else if (initial_states != nullptr
|
||||
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
|
||||
}
|
||||
reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (final_states != nullptr
|
||||
&& l_idx < kWidth - 1
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
// x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1)
|
||||
// So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx]
|
||||
*reinterpret_cast<vec_t *>(final_states) = reinterpret_cast<vec_t *>(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
|
||||
}
|
||||
|
||||
constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
|
||||
static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
|
||||
constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
|
||||
static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
|
||||
// kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
|
||||
static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
|
||||
static_assert((kLPerThread & (kLPerThread - 1)) == 0);
|
||||
static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
|
||||
static_assert(kNThreadsPerRow <= 32);
|
||||
|
||||
const int row_idx = tid / kNThreadsPerRow;
|
||||
const int col_idx = tid % kNThreadsPerRow;
|
||||
|
||||
float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
|
||||
float weight_vals[kWidth] = {0};
|
||||
if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
|
||||
}
|
||||
}
|
||||
float x_vals[kWidth - 1 + kLPerThread];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
||||
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
|
||||
}
|
||||
int seq_idx_thread[kWidth - 1 + kLPerThread];
|
||||
if constexpr (kHasSeqIdx) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
||||
seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
|
||||
}
|
||||
}
|
||||
|
||||
float out_vals[kLPerThread];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kLPerThread; ++i) {
|
||||
out_vals[i] = bias_val;
|
||||
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
if constexpr (!kHasSeqIdx) {
|
||||
out_vals[i] += weight_vals[w] * x_vals[i + w];
|
||||
} else {
|
||||
out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
|
||||
}
|
||||
}
|
||||
if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
||||
input_t out_vals_store[kNElts];
|
||||
reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
|
||||
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
*reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
|
||||
using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
|
||||
// constexpr int kSmemSize = Ktraits::kSmemSize;
|
||||
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
||||
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
||||
const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
|
||||
const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
|
||||
dim3 grid(params.batch, n_chunks_L, n_chunks_C);
|
||||
dim3 block(Ktraits::kNThreads);
|
||||
auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>;
|
||||
// if (kSmemSize >= 48 * 1024) {
|
||||
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
// }
|
||||
// kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
///////
|
||||
|
||||
|
||||
|
||||
|
||||
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_update_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_update_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y * kNThreads + tidx;
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
|
||||
+ channel_id * params.conv_state_c_stride;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
float weight_vals[kWidth] = {0};
|
||||
if (channel_id < params.dim) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
}
|
||||
|
||||
float x_vals[kWidth] = {0};
|
||||
if (channel_id < params.dim) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
|
||||
x_vals[kWidth - 1] = float(x[0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
|
||||
}
|
||||
|
||||
float out_val = bias_val;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
|
||||
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
|
||||
if (channel_id < params.dim) { out[0] = input_t(out_val); }
|
||||
}
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
|
||||
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
|
||||
auto kernel = &causal_conv1d_update_kernel<Ktraits>;
|
||||
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
csrc/mamba/causal_conv1d/causal_conv1d.h (new file, 144 lines)
@@ -0,0 +1,144 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ConvParamsBase {
|
||||
using index_t = uint32_t;
|
||||
|
||||
int batch, dim, seqlen, width;
|
||||
bool silu_activation;
|
||||
|
||||
index_t x_batch_stride;
|
||||
index_t x_c_stride;
|
||||
index_t x_l_stride;
|
||||
index_t weight_c_stride;
|
||||
index_t weight_width_stride;
|
||||
index_t out_batch_stride;
|
||||
index_t out_c_stride;
|
||||
index_t out_l_stride;
|
||||
|
||||
index_t conv_state_batch_stride;
|
||||
index_t conv_state_c_stride;
|
||||
index_t conv_state_l_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ x_ptr;
|
||||
void *__restrict__ weight_ptr;
|
||||
void *__restrict__ bias_ptr;
|
||||
void *__restrict__ out_ptr;
|
||||
|
||||
void *__restrict__ conv_state_ptr;
|
||||
|
||||
void *__restrict__ seq_idx_ptr;
|
||||
|
||||
// No __restrict__ since initial_states could be the same as final_states.
|
||||
void * initial_states_ptr;
|
||||
index_t initial_states_batch_stride;
|
||||
index_t initial_states_l_stride;
|
||||
index_t initial_states_c_stride;
|
||||
|
||||
void * final_states_ptr;
|
||||
index_t final_states_batch_stride;
|
||||
index_t final_states_l_stride;
|
||||
index_t final_states_c_stride;
|
||||
};
|
||||
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor_sync(uint32_t(-1), val, offset);
|
||||
}
|
||||
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return std::max(ilist);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return std::min(a, b);
|
||||
}
|
||||
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor(val, offset);
|
||||
}
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return *std::max_element(ilist.begin(), ilist.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int BYTES> struct BytesToType {};
|
||||
|
||||
template<> struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct SumOp {
|
||||
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
||||
};
|
||||
|
||||
template<int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Allreduce<2> {
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
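A brief device-side usage sketch of the helpers above (hypothetical wrapper name; assumes the SumOp and Allreduce definitions in this header are in scope): after the butterfly exchange, every lane of a warp holds the warp-wide sum.

// Hypothetical device-side usage: Allreduce<32> halves the shuffle offset at
// each step (16, 8, 4, 2, 1), so every lane ends up with the full warp sum.
__device__ float warp_sum_sketch(float val) {
    SumOp<float> op;
    return Allreduce<32>::run(val, op);
}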
csrc/mamba/causal_conv1d/static_switch.h (new file, 28 lines)
@@ -0,0 +1,28 @@
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h

#pragma once

/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///   some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)          \
  [&] {                                             \
    if (COND) {                                     \
      static constexpr bool CONST_NAME = true;      \
      return __VA_ARGS__();                         \
    } else {                                        \
      static constexpr bool CONST_NAME = false;     \
      return __VA_ARGS__();                         \
    }                                               \
  }()
csrc/mamba/mamba_ssm/selective_scan.h (new file, 276 lines)
@@ -0,0 +1,276 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
// clang-format off
|
||||
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#endif
|
||||
#include <cuda_fp16.h>
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SSMParamsBase {
|
||||
using index_t = uint32_t;
|
||||
|
||||
int batch, dim, seqlen, dstate, n_groups, n_chunks;
|
||||
int dim_ngroups_ratio;
|
||||
bool is_variable_B;
|
||||
bool is_variable_C;
|
||||
|
||||
bool delta_softplus;
|
||||
|
||||
index_t A_d_stride;
|
||||
index_t A_dstate_stride;
|
||||
index_t B_batch_stride;
|
||||
index_t B_d_stride;
|
||||
index_t B_dstate_stride;
|
||||
index_t B_group_stride;
|
||||
index_t C_batch_stride;
|
||||
index_t C_d_stride;
|
||||
index_t C_dstate_stride;
|
||||
index_t C_group_stride;
|
||||
index_t u_batch_stride;
|
||||
index_t u_d_stride;
|
||||
index_t delta_batch_stride;
|
||||
index_t delta_d_stride;
|
||||
index_t z_batch_stride;
|
||||
index_t z_d_stride;
|
||||
index_t out_batch_stride;
|
||||
index_t out_d_stride;
|
||||
index_t out_z_batch_stride;
|
||||
index_t out_z_d_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ A_ptr;
|
||||
void *__restrict__ B_ptr;
|
||||
void *__restrict__ C_ptr;
|
||||
void *__restrict__ D_ptr;
|
||||
void *__restrict__ u_ptr;
|
||||
void *__restrict__ delta_ptr;
|
||||
void *__restrict__ delta_bias_ptr;
|
||||
void *__restrict__ out_ptr;
|
||||
void *__restrict__ x_ptr;
|
||||
void *__restrict__ z_ptr;
|
||||
void *__restrict__ out_z_ptr;
|
||||
void *__restrict__ index_ptr;
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return std::max(ilist);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return std::min(a, b);
|
||||
}
|
||||
|
||||
#else
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return *std::max_element(ilist.begin(), ilist.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#define MAX_DSTATE 256
|
||||
|
||||
|
||||
inline __device__ float2 operator+(const float2 & a, const float2 & b){
|
||||
return {a.x + b.x, a.y + b.y};
|
||||
}
|
||||
|
||||
inline __device__ float3 operator+(const float3 &a, const float3 &b) {
|
||||
return {a.x + b.x, a.y + b.y, a.z + b.z};
|
||||
}
|
||||
|
||||
inline __device__ float4 operator+(const float4 & a, const float4 & b){
|
||||
return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int BYTES> struct BytesToType {};
|
||||
|
||||
template<> struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename scalar_t, int N>
|
||||
struct Converter{
|
||||
static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
|
||||
}
|
||||
};
|
||||
|
||||
template<int N>
|
||||
struct Converter<at::Half, N>{
|
||||
static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
|
||||
static_assert(N % 2 == 0);
|
||||
auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
|
||||
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
|
||||
}
|
||||
};
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
template<int N>
|
||||
struct Converter<at::BFloat16, N>{
|
||||
static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
|
||||
static_assert(N % 2 == 0);
|
||||
auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
|
||||
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
template<typename scalar_t> struct SSMScanOp;
|
||||
|
||||
template<>
|
||||
struct SSMScanOp<float> {
|
||||
__device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
|
||||
return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
|
||||
}
|
||||
};
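This operator composes the affine maps h -> a*h + b that the recurrence is built from. A small host sketch (hypothetical names, illustration only) checks that the composition rule above matches applying two steps in sequence:

// Host sketch: scanning per-step (a_t, b_t) pairs with this combine rule
// yields the state after each prefix of steps.
#include <cassert>
#include <cmath>
#include <utility>

using AB = std::pair<float, float>;  // (a, b) represents h -> a*h + b

AB combine(const AB& ab0, const AB& ab1) {  // same rule as SSMScanOp<float>
    return {ab1.first * ab0.first, ab1.first * ab0.second + ab1.second};
}

int main() {
    AB step1{0.5f, 2.0f}, step2{0.25f, 1.0f};
    float h0 = 3.0f;
    float sequential = step2.first * (step1.first * h0 + step1.second) + step2.second;
    AB fused = combine(step1, step2);
    float composed = fused.first * h0 + fused.second;
    assert(std::fabs(sequential - composed) < 1e-6f);
    return 0;
}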
|
||||
|
||||
// A stateful callback functor that maintains a running prefix to be applied
|
||||
// during consecutive scan operations.
|
||||
template <typename scalar_t> struct SSMScanPrefixCallbackOp {
|
||||
using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
|
||||
scan_t running_prefix;
|
||||
// Constructor
|
||||
__device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
|
||||
// Callback operator to be entered by the first warp of threads in the block.
|
||||
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
|
||||
__device__ scan_t operator()(scan_t block_aggregate) {
|
||||
scan_t old_prefix = running_prefix;
|
||||
running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
|
||||
return old_prefix;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void load_input(typename Ktraits::input_t *u,
|
||||
typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockLoadT::TempStorage &smem_load,
|
||||
int seqlen) {
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
|
||||
reinterpret_cast<vec_t*>(u),
|
||||
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
|
||||
#ifdef USE_ROCM
|
||||
, Ktraits::kNThreads * Ktraits::kNLoads
|
||||
#endif
|
||||
|
||||
);
|
||||
} else {
|
||||
typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void load_index(int *u,
|
||||
int (&u_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index,
|
||||
int seqlen) {
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_load_index_vec = reinterpret_cast<typename Ktraits::BlockLoadIndexVecT::TempStorage&>(smem_load_index);
|
||||
Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(
|
||||
reinterpret_cast<uint4*>(u),
|
||||
reinterpret_cast<uint4(&)[Ktraits::kNLoadsIndex]>(u_vals)
|
||||
);
|
||||
} else {
|
||||
Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
|
||||
typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
|
||||
int seqlen) {
|
||||
constexpr int kNItems = Ktraits::kNItems;
|
||||
typename Ktraits::input_t B_vals_load[kNItems];
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
|
||||
reinterpret_cast<vec_t*>(Bvar),
|
||||
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
|
||||
);
|
||||
} else {
|
||||
typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
|
||||
}
|
||||
// #pragma unroll
|
||||
// for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
|
||||
Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
|
||||
}
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void store_output(typename Ktraits::input_t *out,
|
||||
const float (&out_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockStoreT::TempStorage &smem_store,
|
||||
int seqlen) {
|
||||
typename Ktraits::input_t write_vals[Ktraits::kNItems];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
|
||||
reinterpret_cast<vec_t*>(out),
|
||||
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
|
||||
);
|
||||
} else {
|
||||
typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
|
||||
}
|
||||
}
|
csrc/mamba/mamba_ssm/selective_scan_fwd.cu (new file, 593 lines)
@@ -0,0 +1,593 @@
|
||||
// clang-format off
|
||||
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "selective_scan.h"
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
#include <cub/block/block_scan.cuh>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
namespace cub = hipcub;
|
||||
#endif
|
||||
|
||||
#include "selective_scan.h"
|
||||
#include "static_switch.h"
|
||||
|
||||
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
|
||||
bool kIsVariableB_, bool kIsVariableC_,
|
||||
bool kHasZ_, bool kUseIndex_, typename input_t_, typename weight_t_>
|
||||
struct Selective_Scan_fwd_kernel_traits {
|
||||
static_assert(kNItems_ % 4 == 0);
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
|
||||
static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
|
||||
static constexpr int kNItems = kNItems_;
|
||||
static constexpr int kNRows = kNRows_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
|
||||
static_assert(kNItems % kNElts == 0);
|
||||
static constexpr int kNLoads = kNItems / kNElts;
|
||||
static constexpr bool kIsEvenLen = kIsEvenLen_;
|
||||
static constexpr bool kIsVariableB = kIsVariableB_;
|
||||
static constexpr bool kIsVariableC = kIsVariableC_;
|
||||
static constexpr bool kHasZ = kHasZ_;
|
||||
static constexpr bool kUseIndex = kUseIndex_;
|
||||
|
||||
static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
|
||||
static constexpr int kNLoadsIndex = kNItems / 4;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
using scan_t = float2;
|
||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
|
||||
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockLoadIndexT = cub::BlockLoad<int, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadIndexVecT = cub::BlockLoad<uint4, kNThreads, kNLoadsIndex,
|
||||
!(kIsEvenLen && kNLoadsIndex == 1) ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems , cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads ,
|
||||
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
|
||||
!kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
|
||||
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
|
||||
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
|
||||
using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
||||
static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
|
||||
sizeof(typename BlockLoadVecT::TempStorage),
|
||||
sizeof(typename BlockLoadIndexT::TempStorage),
|
||||
sizeof(typename BlockLoadIndexVecT::TempStorage),
|
||||
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
|
||||
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
|
||||
sizeof(typename BlockStoreT::TempStorage),
|
||||
sizeof(typename BlockStoreVecT::TempStorage)});
|
||||
static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
|
||||
void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
constexpr bool kIsVariableB = Ktraits::kIsVariableB;
|
||||
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
|
||||
constexpr bool kHasZ = Ktraits::kHasZ;
|
||||
constexpr bool kUseIndex = Ktraits::kUseIndex;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNItems = Ktraits::kNItems;
|
||||
constexpr int kNRows = Ktraits::kNRows;
|
||||
constexpr bool kDirectIO = Ktraits::kDirectIO;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
using scan_t = typename Ktraits::scan_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
// cast to lvalue reference of expected type
|
||||
// char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
|
||||
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
|
||||
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
|
||||
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
||||
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
|
||||
auto& smem_load_index = reinterpret_cast<typename Ktraits::BlockLoadIndexT::TempStorage&>(smem_);
|
||||
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
|
||||
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
||||
auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
||||
// weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
|
||||
// weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
|
||||
scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
|
||||
|
||||
const int batch_id = blockIdx.x;
|
||||
const int dim_id = blockIdx.y;
|
||||
const int group_id = dim_id / (params.dim_ngroups_ratio);
|
||||
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
|
||||
+ dim_id * kNRows * params.u_d_stride;
|
||||
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
|
||||
+ dim_id * kNRows * params.delta_d_stride;
|
||||
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
|
||||
weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
|
||||
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
|
||||
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
|
||||
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
|
||||
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
|
||||
int *index = !kUseIndex ? nullptr :reinterpret_cast<int *>(params.index_ptr) + batch_id * params.seqlen;
|
||||
|
||||
float D_val[kNRows] = {0};
|
||||
if (params.D_ptr != nullptr) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
|
||||
}
|
||||
}
|
||||
float delta_bias[kNRows] = {0};
|
||||
if (params.delta_bias_ptr != nullptr) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
|
||||
// smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
|
||||
// smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
|
||||
// }
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNItems;
|
||||
for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
|
||||
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
|
||||
int index_vals_load[kNRows][kNItems];
|
||||
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
if constexpr (!kDirectIO) {
|
||||
if (r > 0) { __syncthreads(); }
|
||||
}
|
||||
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
|
||||
if constexpr (!kDirectIO) { __syncthreads(); }
|
||||
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
|
||||
if constexpr (kUseIndex) {
|
||||
load_index<Ktraits>(index + r * params.delta_d_stride, index_vals_load[r], smem_load_index, params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
}
|
||||
if constexpr (kUseIndex) {
|
||||
index += kChunkSize;
|
||||
}
|
||||
u += kChunkSize;
|
||||
delta += kChunkSize;
|
||||
|
||||
float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
float u_val = float(u_vals[r][i]);
|
||||
delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
|
||||
if (params.delta_softplus) {
|
||||
delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
|
||||
}
|
||||
delta_u_vals[r][i] = delta_vals[r][i] * u_val;
|
||||
out_vals[r][i] = D_val[r] * u_val;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
|
||||
weight_t A_val[kNRows];
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
|
||||
// Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
|
||||
constexpr float kLog2e = M_LOG2E;
|
||||
A_val[r] *= kLog2e;
|
||||
}
|
||||
// This variable holds B * C if both B and C are constant across seqlen. If only B varies
|
||||
// across seqlen, this holds C. If only C varies across seqlen, this holds B.
|
||||
// If both B and C vary, this is unused.
|
||||
weight_t BC_val[kNRows];
|
||||
weight_t B_vals[kNItems], C_vals[kNItems];
|
||||
if constexpr (kIsVariableB) {
|
||||
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
||||
smem_load_weight, (params.seqlen - chunk * kChunkSize) * (1));
|
||||
if constexpr (!kIsVariableC) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr (kIsVariableC) {
|
||||
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
|
||||
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
||||
smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (1 ));
|
||||
if constexpr (!kIsVariableB) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr (!kIsVariableB && !kIsVariableC) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
if (r > 0) { __syncthreads(); } // Scan could be using the same smem
|
||||
scan_t thread_data[kNItems];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
|
||||
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
|
||||
|
||||
// Reset A bar for cumulative sequences (Real)
|
||||
if constexpr (kUseIndex) {
|
||||
if (index_vals_load[r][i] == 0) {
|
||||
thread_data[i].x = 0.f;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
|
||||
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
|
||||
thread_data[i] = make_float2(1.f, 0.f);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Initialize running total
|
||||
scan_t running_prefix;
|
||||
// If we use WARP_SCAN then lane 0 of every warp (not just thread 0) needs to read
|
||||
running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : ( threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f));
|
||||
// running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
|
||||
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
||||
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
||||
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
||||
);
|
||||
// There's a syncthreads in the scan op, so we don't need to sync here.
|
||||
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
|
||||
if (threadIdx.x == 0) {
|
||||
smem_running_prefix[state_idx] = prefix_op.running_prefix;
|
||||
x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
const weight_t C_val = !kIsVariableC
|
||||
? BC_val[r]
|
||||
: (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
|
||||
out_vals[r][i] += thread_data[i].y * C_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
if constexpr (!kDirectIO) {
|
||||
if (r > 0) { __syncthreads(); }
|
||||
}
|
||||
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
|
||||
if constexpr (kHasZ) {
|
||||
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
|
||||
+ dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
|
||||
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
|
||||
+ dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
input_t z_vals[kNItems];
|
||||
__syncthreads();
|
||||
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
float z_val = z_vals[i];
|
||||
out_vals[r][i] *= z_val / (1 + expf(-z_val));
|
||||
}
|
||||
__syncthreads();
|
||||
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
}
|
||||
|
||||
Bvar += kChunkSize * 1;
|
||||
Cvar += kChunkSize * 1;
|
||||
}
|
||||
}
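For reference, a scalar host sketch (hypothetical helper; single dim, single state, with the delta softplus and the z gate omitted) of the recurrence the kernel evaluates per state: h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * u_t, with output y_t = C_t * h_t + D * u_t.

// Scalar host sketch of the selective-scan recurrence for one (dim, dstate)
// pair, matching thread_data = (exp(delta*A), B*delta*u) combined by the scan.
#include <cmath>
#include <vector>

std::vector<float> selective_scan_ref(const std::vector<float>& u,
                                      const std::vector<float>& delta,
                                      float A,
                                      const std::vector<float>& B,
                                      const std::vector<float>& C,
                                      float D) {
    std::vector<float> y(u.size());
    float h = 0.f;  // assume zero initial state for this sketch
    for (size_t t = 0; t < u.size(); ++t) {
        h = std::exp(delta[t] * A) * h + delta[t] * B[t] * u[t];
        y[t] = C[t] * h + D * u[t];
    }
    return y;
}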
|
||||
|
||||
template<int kNThreads, int kNItems, typename input_t, typename weight_t>
|
||||
void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
|
||||
// Only kNRows == 1 is tested for now, which of course doesn't differ from the previous
// behavior where each block processed 1 row.
|
||||
constexpr int kNRows = 1;
|
||||
// kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
|
||||
constexpr bool kIsVariableB = true;
|
||||
constexpr bool kIsVariableC = true;
|
||||
constexpr bool kHasZ = true;
|
||||
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
||||
BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] {
|
||||
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
||||
dim3 grid(params.batch, params.dim / kNRows);
|
||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) {
|
||||
|
||||
#ifndef USE_ROCM
|
||||
if (params.seqlen <= 128) {
|
||||
selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 256) {
|
||||
selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 512) {
|
||||
selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 1024) {
|
||||
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
|
||||
} else {
|
||||
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
|
||||
}
|
||||
#else
|
||||
if (params.seqlen <= 256) {
|
||||
selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 512) {
|
||||
selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 1024) {
|
||||
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
|
||||
} else {
|
||||
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
template void selective_scan_fwd_cuda<float, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
||||
if (ITYPE == at::ScalarType::Half) { \
|
||||
using input_t = at::Half; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
||||
using input_t = at::BFloat16; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::Float) { \
|
||||
using input_t = float; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
||||
}
|
||||
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
// sizes
|
||||
const size_t batch,
|
||||
const size_t dim,
|
||||
const size_t seqlen,
|
||||
const size_t dstate,
|
||||
const size_t n_groups,
|
||||
const size_t n_chunks,
|
||||
const bool is_variable_B,
|
||||
const bool is_variable_C,
|
||||
// device pointers
|
||||
const torch::Tensor u,
|
||||
const torch::Tensor delta,
|
||||
const torch::Tensor A,
|
||||
const torch::Tensor B,
|
||||
const torch::Tensor C,
|
||||
const torch::Tensor out,
|
||||
const torch::Tensor z,
|
||||
const torch::Tensor out_z,
|
||||
void* D_ptr,
|
||||
void* delta_bias_ptr,
|
||||
void* x_ptr,
|
||||
bool has_z,
|
||||
bool delta_softplus,
|
||||
void* index_ptr) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(&params, 0, sizeof(params));
|
||||
|
||||
params.batch = batch;
|
||||
params.dim = dim;
|
||||
params.seqlen = seqlen;
|
||||
params.dstate = dstate;
|
||||
params.n_groups = n_groups;
|
||||
params.n_chunks = n_chunks;
|
||||
params.dim_ngroups_ratio = dim / n_groups;
|
||||
|
||||
params.delta_softplus = delta_softplus;
|
||||
|
||||
params.is_variable_B = is_variable_B;
|
||||
params.is_variable_C = is_variable_C;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.u_ptr = u.data_ptr();
|
||||
params.delta_ptr = delta.data_ptr();
|
||||
params.A_ptr = A.data_ptr();
|
||||
params.B_ptr = B.data_ptr();
|
||||
params.C_ptr = C.data_ptr();
|
||||
params.D_ptr = D_ptr;
|
||||
params.delta_bias_ptr = delta_bias_ptr;
|
||||
params.out_ptr = out.data_ptr();
|
||||
params.x_ptr = x_ptr;
|
||||
params.z_ptr = has_z ? z.data_ptr() : nullptr;
|
||||
params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
|
||||
|
||||
params.index_ptr = index_ptr;
|
||||
|
||||
// All stride are in elements, not bytes.
|
||||
params.A_d_stride = A.stride(0);
|
||||
params.A_dstate_stride = A.stride(1);
|
||||
if (!is_variable_B) {
|
||||
params.B_d_stride = B.stride(0);
|
||||
} else {
|
||||
params.B_batch_stride = B.stride(0);
|
||||
params.B_group_stride = B.stride(1);
|
||||
}
|
||||
params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
|
||||
if (!is_variable_C) {
|
||||
params.C_d_stride = C.stride(0);
|
||||
} else {
|
||||
params.C_batch_stride = C.stride(0);
|
||||
params.C_group_stride = C.stride(1);
|
||||
}
|
||||
params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
|
||||
params.u_batch_stride = u.stride(0);
|
||||
params.u_d_stride = u.stride(1);
|
||||
params.delta_batch_stride = delta.stride(0);
|
||||
params.delta_d_stride = delta.stride(1);
|
||||
if (has_z) {
|
||||
params.z_batch_stride = z.stride(0);
|
||||
params.z_d_stride = z.stride(1);
|
||||
params.out_z_batch_stride = out_z.stride(0);
|
||||
params.out_z_d_stride = out_z.stride(1);
|
||||
}
|
||||
params.out_batch_stride = out.stride(0);
|
||||
params.out_d_stride = out.stride(1);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor>
|
||||
selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
|
||||
const c10::optional<torch::Tensor> &D_,
|
||||
const c10::optional<torch::Tensor> &z_,
|
||||
const c10::optional<torch::Tensor> &delta_bias_,
|
||||
bool delta_softplus,
|
||||
const c10::optional<torch::Tensor> &index_,
|
||||
const c10::optional<torch::Tensor> &x) {
|
||||
auto input_type = u.scalar_type();
|
||||
auto weight_type = A.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float);
|
||||
|
||||
const bool is_variable_B = B.dim() >= 3;
|
||||
const bool is_variable_C = C.dim() >= 3;
|
||||
|
||||
TORCH_CHECK(delta.scalar_type() == input_type);
|
||||
TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
|
||||
TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
|
||||
|
||||
TORCH_CHECK(u.is_cuda());
|
||||
TORCH_CHECK(delta.is_cuda());
|
||||
TORCH_CHECK(A.is_cuda());
|
||||
TORCH_CHECK(B.is_cuda());
|
||||
TORCH_CHECK(C.is_cuda());
|
||||
|
||||
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
||||
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
||||
|
||||
const auto sizes = u.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int seqlen = sizes[2];
|
||||
const int dstate = A.size(1);
|
||||
const int n_groups = is_variable_B ? B.size(1) : 1;
|
||||
|
||||
TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
|
||||
|
||||
CHECK_SHAPE(u, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(A, dim, dstate);
|
||||
TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size")
|
||||
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen );
|
||||
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
||||
|
||||
TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size")
|
||||
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
|
||||
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
||||
|
||||
if (D_.has_value()) {
|
||||
auto D = D_.value();
|
||||
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
||||
TORCH_CHECK(D.is_cuda());
|
||||
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
||||
CHECK_SHAPE(D, dim);
|
||||
}
|
||||
|
||||
if (delta_bias_.has_value()) {
|
||||
auto delta_bias = delta_bias_.value();
|
||||
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
||||
TORCH_CHECK(delta_bias.is_cuda());
|
||||
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
||||
CHECK_SHAPE(delta_bias, dim);
|
||||
}
|
||||
if (index_.has_value()) {
|
||||
auto index = index_.value();
|
||||
TORCH_CHECK(index.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(index.is_cuda());
|
||||
CHECK_SHAPE(index, batch_size, seqlen);
|
||||
}
|
||||
|
||||
at::Tensor z, out_z;
|
||||
const bool has_z = z_.has_value();
|
||||
TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
|
||||
z = z_.value();
|
||||
TORCH_CHECK(z.scalar_type() == input_type);
|
||||
TORCH_CHECK(z.is_cuda());
|
||||
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
||||
out_z = torch::empty_like(z);
|
||||
|
||||
const int n_chunks = (seqlen + 2048 - 1) / 2048;
|
||||
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
|
||||
// at::Tensor out = torch::empty_like(u);
|
||||
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
||||
at::Tensor out = torch::empty_like(delta);
|
||||
if (x.has_value()){
|
||||
auto _x = x.value();
|
||||
TORCH_CHECK(_x.scalar_type() == weight_type);
|
||||
TORCH_CHECK(_x.is_cuda());
|
||||
TORCH_CHECK(_x.stride(-1) == 1);
|
||||
CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2);
|
||||
}
|
||||
|
||||
SSMParamsBase params;
|
||||
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
|
||||
u, delta, A, B, C, out, z, out_z,
|
||||
D_.has_value() ? D_.value().data_ptr() : nullptr,
|
||||
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
|
||||
x.value().data_ptr(),
|
||||
has_z,
|
||||
delta_softplus,
|
||||
index_.has_value() ? index_.value().data_ptr() : nullptr);
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
|
||||
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
std::vector<at::Tensor> result = {out, x.value()};
|
||||
if (has_z) { result.push_back(out_z); }
|
||||
return result;
|
||||
}
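// Editorial note (not part of the diff): the op returns {out, x}, plus out_z
// when z is given; the schema registered in csrc/torch_bindings.cpp below
// exposes this as "selective_scan_fwd(...) -> Tensor(a)[]".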

csrc/mamba/mamba_ssm/static_switch.h (new file, 28 lines)
@ -0,0 +1,28 @@
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h

// clang-format off
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/static_switch.h
#pragma once

/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)  \
  [&] {                                     \
    if (COND) {                             \
      constexpr bool CONST_NAME = true;     \
      return __VA_ARGS__();                 \
    } else {                                \
      constexpr bool CONST_NAME = false;    \
      return __VA_ARGS__();                 \
    }                                       \
  }()
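/// Editorial note (not part of the diff): in launchers such as the selective
/// scan kernel, BOOL_SWITCH is typically nested so several runtime flags become
/// compile-time template parameters. A hypothetical sketch (names illustrative):
/// ```
/// BOOL_SWITCH(params.has_z, kHasZ, [&] {
///     BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
///         run_selective_scan<kHasZ, kIsVariableB>(params, stream);
///     });
/// });
/// ```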

csrc/moe/marlin_moe_ops.cu (new file, 1740 lines; file diff suppressed because it is too large)

csrc/moe/marlin_moe_ops.h (new file, 12 lines)
@ -0,0 +1,12 @@
#pragma once

#include <torch/all.h>

torch::Tensor marlin_gemm_moe(
    const torch::Tensor& a, const torch::Tensor& b_q_weights,
    const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
    const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
    const torch::Tensor& g_idx, const torch::Tensor& perm,
    torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
    bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
    bool replicate_input, bool apply_weights);
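// Editorial note (not part of the diff): the matching schema and CUDA impl are
// registered in csrc/moe/torch_bindings.cpp below via
// m.def("marlin_gemm_moe(...)") and m.impl("marlin_gemm_moe", torch::kCUDA,
// &marlin_gemm_moe).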
@ -1,5 +1,6 @@
|
||||
#include "core/registration.h"
|
||||
#include "moe_ops.h"
|
||||
#include "marlin_moe_ops.h"
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
// Apply topk softmax to the gating outputs.
|
||||
@ -7,6 +8,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
|
||||
"token_expert_indices, Tensor gating_output) -> ()");
|
||||
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
m.def(
|
||||
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
|
||||
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
|
||||
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
|
||||
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
|
||||
"bool replicate_input, bool apply_weights) -> Tensor");
|
||||
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
|
||||
#endif
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
48
csrc/ops.h
48
csrc/ops.h
@ -54,10 +54,21 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_quick(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
||||
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
|
||||
int64_t block_size, torch::Tensor& input_tokens,
|
||||
torch::Tensor& sampled_token_ids,
|
||||
torch::Tensor& input_positions,
|
||||
torch::Tensor& seq_lens,
|
||||
torch::Tensor& slot_mapping,
|
||||
torch::Tensor& block_tables);
|
||||
|
||||
void advance_step_flashinfer(
|
||||
int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
||||
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
|
||||
torch::Tensor& input_positions, torch::Tensor& seq_lens,
|
||||
torch::Tensor& slot_mapping, torch::Tensor& block_tables);
|
||||
torch::Tensor& slot_mapping, torch::Tensor& block_tables,
|
||||
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
|
||||
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
||||
@ -123,9 +134,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
int64_t num_bits);
|
||||
|
||||
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
||||
torch::Tensor& perm, c10::SymInt size_k,
|
||||
c10::SymInt size_n, int64_t num_bits);
|
||||
|
||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
||||
int64_t size_n, int64_t num_bits);
|
||||
|
||||
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
||||
c10::SymInt size_k, c10::SymInt size_n,
|
||||
int64_t num_bits);
|
||||
|
||||
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
|
||||
int64_t n);
|
||||
|
||||
@ -170,9 +189,6 @@ void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor& scales);
|
||||
|
||||
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor lookup_table);
|
||||
|
||||
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
||||
torch::Tensor b_gptq_qzeros,
|
||||
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
|
||||
@ -195,6 +211,28 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
|
||||
std::vector<torch::Tensor> selective_scan_fwd(
|
||||
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
|
||||
const torch::Tensor& B, const torch::Tensor& C,
|
||||
const c10::optional<torch::Tensor>& D_,
|
||||
const c10::optional<torch::Tensor>& z_,
|
||||
const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
|
||||
const c10::optional<torch::Tensor>& index_,
|
||||
const c10::optional<torch::Tensor>& x);
|
||||
|
||||
at::Tensor causal_conv1d_update(const at::Tensor& x,
|
||||
const at::Tensor& conv_state,
|
||||
const at::Tensor& weight,
|
||||
const c10::optional<at::Tensor>& bias_,
|
||||
bool silu_activation);
|
||||
|
||||
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
||||
const c10::optional<at::Tensor>& bias_,
|
||||
const c10::optional<at::Tensor>& seq_idx_,
|
||||
const c10::optional<at::Tensor>& initial_states_,
|
||||
const c10::optional<at::Tensor>& final_states_out_,
|
||||
bool silu_activation);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
using fptr_t = int64_t;
|
||||
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
||||
|
@ -12,12 +12,10 @@ namespace prepare_inputs {
|
||||
|
||||
//
|
||||
template <int const num_threads>
|
||||
__global__ void advance_step_kernel(int num_seqs, int num_queries,
|
||||
int block_size, long* input_tokens_ptr,
|
||||
long const* sampled_token_ids_ptr,
|
||||
long* input_positions_ptr,
|
||||
int* seq_lens_ptr, long* slot_mapping_ptr,
|
||||
int const* block_tables_ptr,
|
||||
__global__ void advance_step_flashattn_kernel(
|
||||
int num_seqs, int num_queries, int block_size, long* input_tokens_ptr,
|
||||
long const* sampled_token_ids_ptr, long* input_positions_ptr,
|
||||
int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
|
||||
int64_t const block_tables_stride) {
|
||||
int num_query_blocks = div_ceil(num_queries, num_threads);
|
||||
|
||||
@ -79,7 +77,82 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t,
|
||||
}
|
||||
}
|
||||
|
||||
void advance_step(int num_seqs, int num_queries, int block_size,
|
||||
__global__ void advance_step_flashinfer_kernel(
|
||||
int num_threads, int num_seqs, int num_queries, int block_size,
|
||||
long* input_tokens_ptr, long const* sampled_token_ids_ptr,
|
||||
long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
|
||||
int const* block_tables_ptr, int64_t const block_tables_stride,
|
||||
int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
|
||||
int num_query_blocks = div_ceil(num_queries, num_threads);
|
||||
|
||||
if (blockIdx.x < num_query_blocks) {
|
||||
int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
|
||||
|
||||
if (cur_query_id < num_queries) {
|
||||
// Update input_tokens
|
||||
input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
|
||||
|
||||
int seq_len = seq_lens_ptr[cur_query_id];
|
||||
int next_seq_len = seq_len + 1;
|
||||
int next_input_pos = next_seq_len - 1;
|
||||
|
||||
// Update seq_lens
|
||||
seq_lens_ptr[cur_query_id] = next_seq_len;
|
||||
// Update input_positions
|
||||
input_positions_ptr[cur_query_id] = next_input_pos;
|
||||
|
||||
int const* seq_block_tables_ptr =
|
||||
block_tables_ptr + block_tables_stride * cur_query_id;
|
||||
|
||||
int block_index = next_input_pos / block_size;
|
||||
int block_offset = next_input_pos % block_size;
|
||||
|
||||
// Update paged_kv_last_page_len
|
||||
paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1;
|
||||
|
||||
int slot_num =
|
||||
seq_block_tables_ptr[block_index] * block_size + block_offset;
|
||||
// Update slot_mapping
|
||||
slot_mapping_ptr[cur_query_id] = slot_num;
|
||||
block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void advance_step_flashinfer_indptr_kernel(
|
||||
int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
|
||||
int* block_table_bound_ptr) {
|
||||
int idx = blockIdx.x * num_threads + threadIdx.x;
|
||||
|
||||
// Update paged_kv_indptr
|
||||
if (idx < num_queries) {
|
||||
int sum = 0;
|
||||
for (int i = 0; i <= idx; ++i) {
|
||||
sum += block_table_bound_ptr[i];
|
||||
}
|
||||
paged_kv_indptr_ptr[idx + 1] = sum;
|
||||
}
|
||||
}
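
// Editorial note (not part of the diff): each thread recomputes its prefix sum
// from scratch, so building paged_kv_indptr costs O(num_queries^2) additions;
// effectively paged_kv_indptr[idx + 1] = sum of block_table_bound[0..idx].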
|
||||
|
||||
__global__ void advance_step_flashinfer_indices_kernel(
|
||||
int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr,
|
||||
int64_t const block_tables_stride, int* paged_kv_indices_ptr,
|
||||
int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
|
||||
int idx = blockIdx.x * num_threads + threadIdx.x;
|
||||
int row = idx / block_tables_stride;
|
||||
int col = idx % block_tables_stride;
|
||||
|
||||
if (row < num_queries && col < block_table_bound_ptr[row]) {
|
||||
paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] =
|
||||
block_tables_ptr[row * block_tables_stride + col];
|
||||
}
|
||||
// if cudagraph, fill padded seqs with the last valid seq's indptr
|
||||
if (num_queries < row && row <= num_seqs) {
|
||||
paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries];
|
||||
}
|
||||
}
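
// Editorial note (not part of the diff): idx is linearized over the padded
// block table (stride = block_tables.stride(0)); row selects the sequence and
// col its block slot, and only slots below block_table_bound[row] are copied
// into paged_kv_indices. The second branch copies the last valid indptr into
// padded rows so CUDA-graph padding sees a consistent paged_kv_indptr.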
|
||||
|
||||
void advance_step_flashattn(int num_seqs, int num_queries, int block_size,
|
||||
torch::Tensor& input_tokens, // type: long
|
||||
torch::Tensor& sampled_token_ids, // type: long
|
||||
torch::Tensor& input_positions, // type: long
|
||||
@ -88,7 +161,7 @@ void advance_step(int num_seqs, int num_queries, int block_size,
|
||||
torch::Tensor& block_tables) { // type: int
|
||||
|
||||
if (logging) {
|
||||
printf("advance_step:\n");
|
||||
printf("advance_step_flashattn:\n");
|
||||
printf(" num_seqs = %d\n", num_seqs);
|
||||
printf(" num_queries = %d\n", num_queries);
|
||||
printf(" block_size = %d\n", block_size);
|
||||
@ -108,7 +181,8 @@ void advance_step(int num_seqs, int num_queries, int block_size,
|
||||
int blocks;
|
||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||
|
||||
advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
|
||||
advance_step_flashattn_kernel<max_threads>
|
||||
<<<blocks, max_threads, 0, stream>>>(
|
||||
num_seqs, num_queries, block_size,
|
||||
reinterpret_cast<long*>(input_tokens.data_ptr()),
|
||||
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
|
||||
@ -119,13 +193,114 @@ void advance_step(int num_seqs, int num_queries, int block_size,
|
||||
block_tables.stride(0));
|
||||
}
|
||||
|
||||
void advance_step_flashinfer(
|
||||
int num_seqs, int num_queries, int block_size,
|
||||
torch::Tensor& input_tokens, // type: long
|
||||
torch::Tensor& sampled_token_ids, // type: long
|
||||
torch::Tensor& input_positions, // type: long
|
||||
torch::Tensor& seq_lens, // type: int
|
||||
torch::Tensor& slot_mapping, // type: long
|
||||
torch::Tensor& block_tables, // type: int
|
||||
torch::Tensor& paged_kv_indices, // type: int
|
||||
torch::Tensor& paged_kv_indptr, // type: int
|
||||
torch::Tensor& paged_kv_last_page_len, // type: int
|
||||
torch::Tensor& block_table_bound) { // type: int
|
||||
|
||||
if (logging) {
|
||||
printf("advance_step_flashinfer:\n");
|
||||
printf(" num_seqs = %d\n", num_seqs);
|
||||
printf(" num_queries = %d\n", num_queries);
|
||||
printf(" block_size = %d\n", block_size);
|
||||
printf(" block_tables.stride(0) = %d\n", block_tables.stride(0));
|
||||
}
|
||||
// Verify all tensors
|
||||
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
|
||||
// verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
|
||||
// at::kLong);
|
||||
verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
|
||||
verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
|
||||
verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
|
||||
verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
|
||||
|
||||
verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt);
|
||||
verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt);
|
||||
verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1,
|
||||
at::kInt);
|
||||
|
||||
verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt);
|
||||
|
||||
int dev = sampled_token_ids.get_device();
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
|
||||
|
||||
int blocks;
|
||||
int threads;
|
||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||
cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
|
||||
if (logging) {
|
||||
printf("launching kernel with %d blocks\n", blocks);
|
||||
}
|
||||
|
||||
// TODO(will): support arbitrary block_tables stride
|
||||
if ((blocks * threads) / block_tables.stride(0) < num_queries) {
|
||||
TORCH_CHECK(false,
|
||||
"multi-step: not enough threads to map block_table to"
|
||||
"FlashInfer's paged_kv_indices on GPU. Try reducing the number "
|
||||
"of seqs,",
|
||||
" increasing the block size or take smaller steps.",
|
||||
" num_queries = ", num_queries,
|
||||
" block_tables.stride(0) = ", block_tables.stride(0),
|
||||
" blocks = ", blocks, " max_threads = ", threads);
|
||||
}
|
||||
|
||||
advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
|
||||
threads, num_seqs, num_queries, block_size,
|
||||
reinterpret_cast<long*>(input_tokens.data_ptr()),
|
||||
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
|
||||
reinterpret_cast<long*>(input_positions.data_ptr()),
|
||||
reinterpret_cast<int*>(seq_lens.data_ptr()),
|
||||
reinterpret_cast<long*>(slot_mapping.data_ptr()),
|
||||
reinterpret_cast<int const*>(block_tables.data_ptr()),
|
||||
block_tables.stride(0),
|
||||
reinterpret_cast<int*>(paged_kv_last_page_len.data_ptr()),
|
||||
reinterpret_cast<int*>(block_table_bound.data_ptr()));
|
||||
|
||||
advance_step_flashinfer_indptr_kernel<<<blocks, threads, 0, stream>>>(
|
||||
threads, num_seqs, num_queries,
|
||||
reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
|
||||
reinterpret_cast<int*>(block_table_bound.data_ptr()));
|
||||
|
||||
advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
|
||||
threads, num_seqs, num_queries,
|
||||
reinterpret_cast<int const*>(block_tables.data_ptr()),
|
||||
block_tables.stride(0),
|
||||
reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
|
||||
reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
|
||||
reinterpret_cast<int*>(block_table_bound.data_ptr()));
|
||||
}
|
||||
|
||||
} // namespace prepare_inputs
|
||||
|
||||
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
||||
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
|
||||
int64_t block_size, torch::Tensor& input_tokens,
|
||||
torch::Tensor& sampled_token_ids,
|
||||
torch::Tensor& input_positions,
|
||||
torch::Tensor& seq_lens,
|
||||
torch::Tensor& slot_mapping,
|
||||
torch::Tensor& block_tables) {
|
||||
prepare_inputs::advance_step_flashattn(
|
||||
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
|
||||
input_positions, seq_lens, slot_mapping, block_tables);
|
||||
}
|
||||
|
||||
void advance_step_flashinfer(
|
||||
int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
||||
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
|
||||
torch::Tensor& input_positions, torch::Tensor& seq_lens,
|
||||
torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
|
||||
prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
|
||||
sampled_token_ids, input_positions, seq_lens,
|
||||
slot_mapping, block_tables);
|
||||
torch::Tensor& slot_mapping, torch::Tensor& block_tables,
|
||||
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
|
||||
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) {
|
||||
prepare_inputs::advance_step_flashinfer(
|
||||
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
|
||||
input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
|
||||
paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
|
||||
}
|
@ -267,3 +267,15 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
||||
c10::SymInt size_k, c10::SymInt size_n,
|
||||
int64_t num_bits) {
|
||||
int const pack_factor = 32 / num_bits;
|
||||
auto options = torch::TensorOptions()
|
||||
.dtype(b_q_weight.dtype())
|
||||
.device(b_q_weight.device());
|
||||
return torch::empty_symint(
|
||||
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
|
||||
options);
|
||||
}
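
// Editorial note (not part of the diff): this meta overload (registered under
// torch::kMeta in torch_bindings) only allocates an empty tensor with the
// repacked shape from SymInt sizes, so torch.compile can infer output shapes
// without launching the CUDA repack kernel.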
|
||||
|
@ -342,3 +342,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
|
||||
torch::Tensor& perm, c10::SymInt size_k,
|
||||
c10::SymInt size_n, int64_t num_bits) {
|
||||
int const pack_factor = 32 / num_bits;
|
||||
auto options = torch::TensorOptions()
|
||||
.dtype(b_q_weight.dtype())
|
||||
.device(b_q_weight.device());
|
||||
return torch::empty_symint(
|
||||
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
|
||||
options);
|
||||
}
|
||||
|
@ -1,216 +0,0 @@
|
||||
#include <torch/all.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
// half-tensor
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <ATen/cuda/CUDATensorMethods.cuh>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#define BLOCKWIDTH 128
|
||||
#define BLOCKHEIGHT4 16
|
||||
|
||||
namespace vllm {
|
||||
namespace squeezellm {
|
||||
|
||||
__device__ inline unsigned int as_unsigned(int i) {
|
||||
return *reinterpret_cast<unsigned int*>(&i);
|
||||
}
|
||||
|
||||
// 4-bit matvec kernel (LUT-based)
|
||||
__global__ void NUQ4MatMulKernel(
|
||||
#ifndef USE_ROCM
|
||||
const half2* __restrict__ vec,
|
||||
#else
|
||||
const __half2* __restrict__ vec,
|
||||
#endif
|
||||
const int* __restrict__ mat,
|
||||
#ifndef USE_ROCM
|
||||
half2* __restrict__ mul,
|
||||
#else
|
||||
float2* __restrict__ mul,
|
||||
#endif
|
||||
const __half* __restrict__ lookup_table, int height, int width, int batch,
|
||||
int vec_height) {
|
||||
|
||||
const int blockwidth2 = BLOCKWIDTH / 2;
|
||||
|
||||
int row = BLOCKHEIGHT4 * blockIdx.x;
|
||||
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
|
||||
#ifndef USE_ROCM
|
||||
__shared__ half2 blockvec[blockwidth2];
|
||||
#else
|
||||
__shared__ __half2 blockvec[blockwidth2];
|
||||
#endif
|
||||
|
||||
__shared__ __half deq2[16][BLOCKWIDTH];
|
||||
int off = threadIdx.x;
|
||||
int column_offset = col * 16;
|
||||
for (int val = 0; val < 16; val += 1) {
|
||||
int lut_index = column_offset + val;
|
||||
deq2[val][off] = lookup_table[lut_index];
|
||||
}
|
||||
|
||||
__half res;
|
||||
#ifndef USE_ROCM
|
||||
half2 res2;
|
||||
half2 tmp2;
|
||||
#else
|
||||
__half2 res2;
|
||||
__half2 tmp2;
|
||||
#endif
|
||||
|
||||
int i;
|
||||
int k;
|
||||
|
||||
unsigned int tmp1;
|
||||
unsigned int lut_index1, lut_index2;
|
||||
|
||||
for (int b = 0; b < batch; ++b) {
|
||||
i = width * row + col;
|
||||
res = __int2half_rd(0);
|
||||
k = 0;
|
||||
|
||||
__syncthreads();
|
||||
if (threadIdx.x < blockwidth2)
|
||||
blockvec[threadIdx.x] =
|
||||
vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 +
|
||||
threadIdx.x];
|
||||
__syncthreads();
|
||||
|
||||
while (k < blockwidth2) {
|
||||
tmp1 = as_unsigned(mat[i]);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
res2 = {};
|
||||
tmp2 = {};
|
||||
#else
|
||||
res2.x = __half_as_ushort(__float2half(0));
|
||||
res2.y = __half_as_ushort(__float2half(0));
|
||||
tmp2.x = __half_as_ushort(__float2half(0));
|
||||
tmp2.y = __half_as_ushort(__float2half(0));
|
||||
#endif
|
||||
|
||||
lut_index1 = tmp1 & 0xF;
|
||||
lut_index2 = (tmp1 >> 4) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 0], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 8) & 0xF;
|
||||
lut_index2 = (tmp1 >> 12) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 1], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 16) & 0xF;
|
||||
lut_index2 = (tmp1 >> 20) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 2], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 24) & 0xF;
|
||||
lut_index2 = (tmp1 >> 28) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 3], res2);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
res = __hadd(__hadd(res2.x, res2.y), res);
|
||||
#else
|
||||
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
|
||||
res);
|
||||
#endif
|
||||
|
||||
i += width;
|
||||
k += 4;
|
||||
}
|
||||
|
||||
// col%2 -> only set one of the two values
|
||||
#ifndef USE_ROCM
|
||||
half2 res3 = {};
|
||||
if (col % 2 == 0) {
|
||||
res3.x = res;
|
||||
} else {
|
||||
res3.y = res;
|
||||
}
|
||||
#else
|
||||
__half2 res3;
|
||||
res3.x = __half_as_ushort(__float2half(0));
|
||||
res3.y = __half_as_ushort(__float2half(0));
|
||||
if (col % 2 == 0) {
|
||||
res3.x = __half_as_ushort(res);
|
||||
} else {
|
||||
res3.y = __half_as_ushort(res);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
atomicAdd(&mul[b * width / 2 + col / 2], res3);
|
||||
#else
|
||||
int tmp_addr = b * width / 2 + col / 2;
|
||||
atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
|
||||
atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace squeezellm
|
||||
} // namespace vllm
|
||||
|
||||
// 4-bit matvec kernel (LUT-based)
|
||||
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
|
||||
torch::Tensor lookup_table) {
|
||||
int height = mat.size(0);
|
||||
int width = mat.size(1);
|
||||
|
||||
int batch = vec.size(0);
|
||||
int vec_height = vec.size(1);
|
||||
|
||||
dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
|
||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH);
|
||||
dim3 threads(BLOCKWIDTH);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
|
||||
#ifndef USE_ROCM
|
||||
(half2*)vec.data_ptr<at::Half>(),
|
||||
#else
|
||||
(__half2*)vec.data_ptr<at::Half>(),
|
||||
#endif
|
||||
mat.data_ptr<int>(),
|
||||
#ifndef USE_ROCM
|
||||
(half2*)mul.data_ptr<at::Half>(),
|
||||
(__half*)lookup_table.data_ptr<at::Half>(),
|
||||
#else
|
||||
(float2*)mul.data_ptr<float>(),
|
||||
(__half*)lookup_table.data_ptr<at::Half>(),
|
||||
#endif
|
||||
height, width, batch, vec_height);
|
||||
}
|
||||
|
||||
#undef BLOCKWIDTH
|
||||
#undef BLOCKHEIGHT4
|
@ -36,8 +36,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// PagedAttention V2.
|
||||
ops.def(
|
||||
"paged_attention_v2("
|
||||
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
|
||||
" Tensor tmp_out, Tensor query, Tensor key_cache,"
|
||||
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
|
||||
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
@ -73,8 +73,23 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
|
||||
|
||||
// prepare_inputs advance_step
|
||||
ops.def("advance_step", &advance_step);
|
||||
ops.impl("advance_step", torch::kCUDA, &advance_step);
|
||||
ops.def(
|
||||
"advance_step_flashattn(int num_seqs, int num_queries, int block_size, "
|
||||
"Tensor! input_tokens, Tensor sampled_token_ids, "
|
||||
"Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
|
||||
"Tensor block_tables) -> ()");
|
||||
ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn);
|
||||
|
||||
ops.def(
|
||||
"advance_step_flashinfer("
|
||||
" int num_seqs, int num_queries, int block_size,"
|
||||
" Tensor! input_tokens, Tensor sampled_token_ids,"
|
||||
" Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping,"
|
||||
" Tensor block_tables, Tensor! paged_kv_indices,"
|
||||
" Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len,"
|
||||
" Tensor! block_table_bounds"
|
||||
") -> ()");
|
||||
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
|
||||
|
||||
// Layernorm
|
||||
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
||||
@ -110,27 +125,56 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// Quantization ops
|
||||
#ifndef USE_ROCM
|
||||
// Quantized GEMM for AQLM.
|
||||
ops.def("aqlm_gemm", &aqlm_gemm);
|
||||
ops.def(
|
||||
"aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
|
||||
"Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
|
||||
"-> Tensor");
|
||||
ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
|
||||
|
||||
// Decompression method for AQLM.
|
||||
ops.def("aqlm_dequant", &aqlm_dequant);
|
||||
ops.def(
|
||||
"aqlm_dequant(Tensor codes, Tensor codebooks, "
|
||||
"int[] codebook_partition_sizes) -> Tensor");
|
||||
ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
|
||||
|
||||
// Quantized GEMM for AWQ.
|
||||
ops.def("awq_gemm", &awq_gemm);
|
||||
ops.def(
|
||||
"awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
|
||||
"Tensor _zeros, int split_k_iters) -> Tensor");
|
||||
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
|
||||
|
||||
// Dequantization for AWQ.
|
||||
ops.def("awq_dequantize", &awq_dequantize);
|
||||
ops.def(
|
||||
"awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
|
||||
"Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor");
|
||||
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
|
||||
|
||||
// Note about marlin kernel 'workspace' arguments:
|
||||
// Technically these should be mutable since they are modified by the kernel.
|
||||
// But since they are set back to zero once the kernel is finished we can
|
||||
// hand wave and say that they have no net effect.
|
||||
//
|
||||
// The reason to mark 'workspace' as immutable is so that they don't interfere
|
||||
// with using ScalarType arguments in the ops. If they are marked as mutable,
|
||||
// pytorch throws an assert in
|
||||
// 'torch._higher_order_ops._register_effectful_op' that prevents these
|
||||
// kernels from being torch.compile'd.
|
||||
// See the following document for more info on custom types and ops that use
|
||||
// custom types:
|
||||
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
|
||||
|
||||
// Marlin (Dense) Optimized Quantized GEMM for GPTQ.
|
||||
ops.def("marlin_gemm", &marlin_gemm);
|
||||
ops.def(
|
||||
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
||||
"Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
|
||||
ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
|
||||
|
||||
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
|
||||
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
|
||||
ops.def(
|
||||
"gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
|
||||
"Tensor b_scales, Tensor workspace, "
|
||||
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
|
||||
"int size_m, int size_n, int size_k) -> Tensor");
|
||||
ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
|
||||
|
||||
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
|
||||
@ -149,35 +193,55 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
|
||||
|
||||
// gptq_marlin Optimized Quantized GEMM for GPTQ.
|
||||
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
|
||||
ops.def(
|
||||
"gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
||||
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
|
||||
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
|
||||
"int size_m, int size_n, int size_k, bool is_k_full, "
|
||||
"bool has_zp, bool use_fp32_reduce) -> Tensor");
|
||||
ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
|
||||
|
||||
// gptq_marlin repack from GPTQ.
|
||||
ops.def("gptq_marlin_repack", &gptq_marlin_repack);
|
||||
ops.def(
|
||||
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
|
||||
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
|
||||
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
|
||||
ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta);
|
||||
|
||||
// awq_marlin repack from AWQ.
|
||||
ops.def("awq_marlin_repack", &awq_marlin_repack);
|
||||
ops.def(
|
||||
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
|
||||
"SymInt size_n, int num_bits) -> Tensor");
|
||||
ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
|
||||
ops.impl("awq_marlin_repack", torch::kMeta, &awq_marlin_repack_meta);
|
||||
|
||||
// Dequantization for GGML.
|
||||
ops.def("ggml_dequantize", &ggml_dequantize);
|
||||
ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
|
||||
ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
|
||||
|
||||
// mmvq kernel for GGML.
|
||||
ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8);
|
||||
ops.def(
|
||||
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, int row) "
|
||||
"-> Tensor");
|
||||
ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
|
||||
|
||||
// mmq kernel for GGML.
|
||||
ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8);
|
||||
ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, int row) -> Tensor");
|
||||
ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
|
||||
|
||||
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
|
||||
ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
|
||||
ops.def(
|
||||
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
||||
"Tensor! workspace, int num_bits, int size_m, int size_n, "
|
||||
"int size_k) -> Tensor");
|
||||
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
|
||||
|
||||
// marlin_qqq_gemm for QQQ.
|
||||
ops.def("marlin_qqq_gemm", &marlin_qqq_gemm);
|
||||
ops.def(
|
||||
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
|
||||
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
|
||||
"Tensor! workspace, int size_m, int size_n, "
|
||||
"int size_k) -> Tensor");
|
||||
ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
|
||||
|
||||
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
@ -199,25 +263,49 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
// Check if cutlass scaled_mm is supported for CUDA devices of the given
|
||||
// capability
|
||||
ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
|
||||
ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
|
||||
&cutlass_scaled_mm_supports_fp8);
|
||||
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
|
||||
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
|
||||
|
||||
// Mamba selective scan kernel
|
||||
ops.def(
|
||||
"selective_scan_fwd(Tensor! u, Tensor! delta,"
|
||||
"Tensor! A, Tensor! B, Tensor! C,"
|
||||
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
|
||||
"bool delta_softplus,"
|
||||
"Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
|
||||
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
||||
|
||||
ops.def(
|
||||
"causal_conv1d_update(Tensor! x,"
|
||||
"Tensor! conv_state,"
|
||||
"Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"bool silu_activation) -> Tensor");
|
||||
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
|
||||
|
||||
ops.def(
|
||||
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"Tensor? seq_idx_,"
|
||||
"Tensor? initial_states_,"
|
||||
"Tensor? final_states_out_,"
|
||||
"bool silu_activation) -> Tensor");
|
||||
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
|
||||
#endif
|
||||
|
||||
// Quantized GEMM for GPTQ.
|
||||
ops.def("gptq_gemm", &gptq_gemm);
|
||||
// Note: even though the C++ inferred schema is correct for this op, it seems
|
||||
// to prevent the meta function registry.
|
||||
ops.def(
|
||||
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
|
||||
"Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
|
||||
"-> Tensor");
|
||||
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
|
||||
|
||||
// Post processing for GPTQ.
|
||||
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
|
||||
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
|
||||
|
||||
// Quantized GEMM for SqueezeLLM.
|
||||
ops.def(
|
||||
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
|
||||
"lookup_table) -> ()");
|
||||
ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm);
|
||||
|
||||
// Compute FP8 quantized tensor for given scaling factor.
|
||||
ops.def(
|
||||
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
|
||||
@ -231,8 +319,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
|
||||
ops.def(
|
||||
"dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
|
||||
"scale, Tensor? scale_ub) -> "
|
||||
"dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, "
|
||||
"Tensor! scale, Tensor? scale_ub) -> "
|
||||
"()");
|
||||
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
|
||||
&dynamic_per_token_scaled_fp8_quant);
|
||||
@ -269,8 +357,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
|
||||
// Copy the cache blocks from src to dst.
|
||||
cache_ops.def(
|
||||
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
|
||||
"block_mapping) -> ()");
|
||||
"copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
|
||||
"Tensor block_mapping) -> ()");
|
||||
cache_ops.impl("copy_blocks", torch::kCUDA, ©_blocks);
|
||||
|
||||
// Reshape the key and value tensors and cache them.
|
||||
@ -295,8 +383,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
|
||||
// Convert the key and value cache to fp8 data type.
|
||||
cache_ops.def(
|
||||
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
|
||||
"kv_cache_dtype) -> ()");
|
||||
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
|
||||
"str kv_cache_dtype) -> ()");
|
||||
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
|
||||
}
|
||||
|
||||
@ -304,24 +392,28 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
|
||||
// Cuda utils
|
||||
|
||||
// Gets the specified device attribute.
|
||||
cuda_utils.def("get_device_attribute", &get_device_attribute);
|
||||
cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute);
|
||||
cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
|
||||
cuda_utils.impl("get_device_attribute", &get_device_attribute);
|
||||
|
||||
// Gets the maximum shared memory per block device attribute.
|
||||
cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
|
||||
&get_max_shared_memory_per_block_device_attribute);
|
||||
cuda_utils.def(
|
||||
"get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
|
||||
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
|
||||
torch::kCUDA,
|
||||
&get_max_shared_memory_per_block_device_attribute);
|
||||
}
|
||||
|
||||
#ifndef USE_ROCM
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
|
||||
// Custom all-reduce kernels
|
||||
custom_ar.def("init_custom_ar", &init_custom_ar);
|
||||
custom_ar.def(
|
||||
"init_custom_ar(Tensor meta, Tensor rank_data, "
|
||||
"str[] handles, int[] offsets, int rank, "
|
||||
"bool full_nvlink) -> int");
|
||||
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
|
||||
|
||||
custom_ar.def("should_custom_ar", &should_custom_ar);
|
||||
custom_ar.def(
|
||||
"should_custom_ar(Tensor inp, int max_size, int world_size, "
|
||||
"bool full_nvlink) -> bool");
|
||||
custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
|
||||
|
||||
custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
|
||||
@ -333,21 +425,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
|
||||
custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
|
||||
|
||||
custom_ar.def("dispose", &dispose);
|
||||
custom_ar.impl("dispose", torch::kCPU, &dispose);
|
||||
|
||||
custom_ar.def("meta_size", &meta_size);
|
||||
custom_ar.impl("meta_size", torch::kCPU, &meta_size);
|
||||
|
||||
custom_ar.def("register_buffer", ®ister_buffer);
|
||||
custom_ar.def(
|
||||
"register_buffer(int fa, Tensor t, str[] handles, "
|
||||
"int[] offsets) -> ()");
|
||||
custom_ar.impl("register_buffer", torch::kCUDA, ®ister_buffer);
|
||||
|
||||
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
|
||||
custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU,
|
||||
&get_graph_buffer_ipc_meta);
|
||||
|
||||
custom_ar.def("register_graph_buffers", ®ister_graph_buffers);
|
||||
custom_ar.impl("register_graph_buffers", torch::kCPU,
|
||||
&register_graph_buffers);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -11,4 +11,5 @@ pydantic >= 2.8
|
||||
torch
|
||||
py-cpuinfo
|
||||
transformers
|
||||
mistral_common >= 1.3.4
|
||||
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
|
@ -5,6 +5,7 @@ vLLM Meetups
|
||||
|
||||
We host regular meetups in the San Francisco Bay Area every 2 months. We share project updates from the vLLM team and invite guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:
|
||||
|
||||
- `The sixth vLLM meetup <https://lu.ma/87q3nvnh>`__, with NVIDIA, September 9th 2024. `[Slides] <https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing>`__
|
||||
- `The fifth vLLM meetup <https://lu.ma/lp0gyjqr>`__, with AWS, July 24th 2024. `[Slides] <https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing>`__
|
||||
- `The fourth vLLM meetup <https://lu.ma/agivllm>`__, with Cloudflare and BentoML, June 11th 2024. `[Slides] <https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing>`__
|
||||
- `The third vLLM meetup <https://robloxandvllmmeetup2024.splashthat.com/>`__, with Roblox, April 2nd 2024. `[Slides] <https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing>`__
|
||||
|
@ -99,6 +99,7 @@ autodoc_mock_imports = [
|
||||
"aiohttp",
|
||||
"compressed_tensors",
|
||||
"cpuinfo",
|
||||
"cv2",
|
||||
"torch",
|
||||
"transformers",
|
||||
"psutil",
|
||||
|
@ -45,8 +45,6 @@ Base Classes
|
||||
|
||||
.. autodata:: vllm.multimodal.NestedTensors
|
||||
|
||||
.. autodata:: vllm.multimodal.BatchedTensors
|
||||
|
||||
.. autodata:: vllm.multimodal.BatchedTensorInputs
|
||||
|
||||
.. autoclass:: vllm.multimodal.MultiModalDataBuiltins
|
||||
|
@ -18,13 +18,27 @@ Traces can be visualized using https://ui.perfetto.dev/.
|
||||
|
||||
Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, there is no need to untar the traces; they can be viewed directly.
|
||||
|
||||
Example commands:
|
||||
.. tip::
|
||||
|
||||
Stopping the profiler flushes all the profile trace files to the directory. This takes time; for example, for about 100 requests' worth of data for a Llama 70B model, it takes about 10 minutes to flush on an H100.
Before you start the server, set the environment variable VLLM_RPC_GET_DATA_TIMEOUT_MS to a large value, e.g. 30 minutes:
``export VLLM_RPC_GET_DATA_TIMEOUT_MS=1800000``
|
||||
|
||||
Example commands and usage:
|
||||
===========================
|
||||
|
||||
Offline Inference:
|
||||
------------------
|
||||
|
||||
Refer to `examples/offline_inference_with_profiler.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_with_profiler.py>`_ for an example.
|
||||
|
||||
|
||||
OpenAI Server:
|
||||
--------------
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
VLLM_TORCH_PROFILER_DIR=/mnt/traces/ python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B
|
||||
VLLM_TORCH_PROFILER_DIR=./vllm_profile python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B
|
||||
|
||||
benchmark_serving.py:
|
||||
|
||||
|
@ -21,7 +21,7 @@ If you have already taken care of the above issues, but the vLLM instance still
|
||||
|
||||
With more logging, hopefully you can find the root cause of the issue.
|
||||
|
||||
If it crashes, and the error trace shows somewhere around ``self.graph.replay()`` in ``vllm/worker/model_runner.py``, it is a cuda error inside cudagraph. To know the particular cuda operation that causes the error, you can add ``--enforce-eager`` to the command line, or ``enforce_eager=True`` to the ``LLM`` class, to disable the cudagraph optimization. This way, you can locate the exact cuda operation that causes the error.
|
||||
If it crashes, and the error trace shows somewhere around ``self.graph.replay()`` in ``vllm/worker/model_runner.py``, it is a cuda error inside cudagraph. To know the particular cuda operation that causes the error, you can add ``--enforce-eager`` to the command line, or ``enforce_eager=True`` to the :class:`~vllm.LLM` class, to disable the cudagraph optimization. This way, you can locate the exact cuda operation that causes the error.
|
||||
|
||||
Here are some common issues that can cause hangs:
|
||||
|
||||
|
@ -24,7 +24,9 @@ Offline Batched Inference
|
||||
|
||||
We first show an example of using vLLM for offline batched inference on a dataset. In other words, we use vLLM to generate texts for a list of input prompts.
|
||||
|
||||
Import ``LLM`` and ``SamplingParams`` from vLLM. The ``LLM`` class is the main class for running offline inference with vLLM engine. The ``SamplingParams`` class specifies the parameters for the sampling process.
|
||||
Import :class:`~vllm.LLM` and :class:`~vllm.SamplingParams` from vLLM.
|
||||
The :class:`~vllm.LLM` class is the main class for running offline inference with vLLM engine.
|
||||
The :class:`~vllm.SamplingParams` class specifies the parameters for the sampling process.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@ -42,7 +44,7 @@ Define the list of input prompts and the sampling parameters for generation. The
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
Initialize vLLM's engine for offline inference with the ``LLM`` class and the `OPT-125M model <https://arxiv.org/abs/2205.01068>`_. The list of supported models can be found at :ref:`supported models <supported_models>`.
|
||||
Initialize vLLM's engine for offline inference with the :class:`~vllm.LLM` class and the `OPT-125M model <https://arxiv.org/abs/2205.01068>`_. The list of supported models can be found at :ref:`supported models <supported_models>`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -56,9 +56,10 @@ First, install the dependencies:
|
||||
$ pip uninstall torch torch-xla -y
|
||||
|
||||
$ # Install PyTorch and PyTorch XLA.
|
||||
$ export DATE="+20240808"
|
||||
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
|
||||
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
|
||||
$ export DATE="20240828"
|
||||
$ export TORCH_VERSION="2.5.0"
|
||||
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
|
||||
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
|
||||
|
||||
$ # Install JAX and Pallas.
|
||||
$ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||
|
@ -107,3 +107,55 @@ The following is an example request
|
||||
"max_tokens": 7,
|
||||
"temperature": 0
|
||||
}' | jq
|
||||
|
||||
|
||||
Dynamically serving LoRA Adapters
|
||||
---------------------------------
|
||||
|
||||
In addition to serving LoRA adapters at server startup, the vLLM server now supports dynamically loading and unloading
|
||||
LoRA adapters at runtime through dedicated API endpoints. This feature can be particularly useful when the flexibility
|
||||
to change models on-the-fly is needed.
|
||||
|
||||
Note: Enabling this feature in production environments is risky, as it allows users to participate in model adapter management.
|
||||
|
||||
To enable dynamic LoRA loading and unloading, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING`
|
||||
is set to `True`. When this option is enabled, the API server will log a warning to indicate that dynamic loading is active.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True
|
||||
|
||||
|
||||
Loading a LoRA Adapter:
|
||||
|
||||
To dynamically load a LoRA adapter, send a POST request to the `/v1/load_lora_adapter` endpoint with the necessary
|
||||
details of the adapter to be loaded. The request payload should include the name and path to the LoRA adapter.
|
||||
|
||||
Example request to load a LoRA adapter:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
curl -X POST http://localhost:8000/v1/load_lora_adapter \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"lora_name": "sql_adapter",
|
||||
"lora_path": "/path/to/sql-lora-adapter"
|
||||
}'
|
||||
|
||||
Upon a successful request, the API will respond with a 200 OK status code. If an error occurs, such as if the adapter
|
||||
cannot be found or loaded, an appropriate error message will be returned.
|
||||
|
||||
Unloading a LoRA Adapter:
|
||||
|
||||
To unload a LoRA adapter that has been previously loaded, send a POST request to the `/v1/unload_lora_adapter` endpoint
|
||||
with the name or ID of the adapter to be unloaded.
|
||||
|
||||
Example request to unload a LoRA adapter:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
curl -X POST http://localhost:8000/v1/unload_lora_adapter \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"lora_name": "sql_adapter"
|
||||
}'
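
The same endpoints can also be driven programmatically. The following is a minimal sketch using the ``requests`` library, assuming the server runs on ``localhost:8000`` and reusing the example adapter name and path from above:

.. code-block:: python

    import requests

    BASE_URL = "http://localhost:8000"

    # Load the adapter (example name and path, as in the curl request above).
    resp = requests.post(
        f"{BASE_URL}/v1/load_lora_adapter",
        json={"lora_name": "sql_adapter",
              "lora_path": "/path/to/sql-lora-adapter"},
    )
    resp.raise_for_status()

    # ... serve traffic that requests "sql_adapter" as the model ...

    # Unload the adapter once it is no longer needed.
    resp = requests.post(
        f"{BASE_URL}/v1/unload_lora_adapter",
        json={"lora_name": "sql_adapter"},
    )
    resp.raise_for_status()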
|
||||
|
@ -161,6 +161,46 @@ A variety of speculative models of this type are available on HF hub:
|
||||
* `granite-7b-instruct-accelerator <https://huggingface.co/ibm-granite/granite-7b-instruct-accelerator>`_
|
||||
* `granite-20b-code-instruct-accelerator <https://huggingface.co/ibm-granite/granite-20b-code-instruct-accelerator>`_
|
||||
|
||||
Lossless guarantees of Speculative Decoding
|
||||
-------------------------------------------
|
||||
In vLLM, speculative decoding aims to enhance inference efficiency while maintaining accuracy. This section addresses the lossless guarantees of
|
||||
speculative decoding, breaking down the guarantees into three key areas:
|
||||
|
||||
1. **Theoretical Losslessness**
|
||||
- Speculative decoding sampling is theoretically lossless up to the precision limits of hardware numerics. Floating-point errors might
|
||||
cause slight variations in output distributions, as discussed
|
||||
in `Accelerating Large Language Model Decoding with Speculative Sampling <https://arxiv.org/pdf/2302.01318>`_
|
||||
|
||||
2. **Algorithmic Losslessness**
|
||||
- vLLM’s implementation of speculative decoding is algorithmically validated to be lossless. Key validation tests include:
|
||||
|
||||
- **Rejection Sampler Convergence**: Ensures that samples from vLLM’s rejection sampler align with the target
|
||||
distribution. `View Test Code <https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252>`_
|
||||
|
||||
- **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling
|
||||
without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler,
|
||||
provides a lossless guarantee. Almost all of the tests in `this directory <https://github.com/vllm-project/vllm/tree/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e>`_
|
||||
verify this property using `this assertion implementation <https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291>`_
|
||||
|
||||
3. **vLLM Logprob Stability**
|
||||
- vLLM does not currently guarantee stable token log probabilities (logprobs). This can result in different outputs for the
|
||||
same request across runs. For more details, see the FAQ section
|
||||
titled *Can the output of a prompt vary across runs in vLLM?* in the `FAQs <../serving/faq.rst>`_.
|
||||
|
||||
|
||||
**Conclusion**
|
||||
|
||||
While vLLM strives to ensure losslessness in speculative decoding, variations in generated outputs with and without speculative decoding
|
||||
can occur due to the following factors:
|
||||
|
||||
- **Floating-Point Precision**: Differences in hardware numerical precision may lead to slight discrepancies in the output distribution.
|
||||
|
||||
- **Batch Size and Numerical Stability**: Changes in batch size may cause variations in logprobs and output probabilities, potentially
|
||||
due to non-deterministic behavior in batched operations or numerical instability.
|
||||
|
||||
**Mitigation Strategies**
|
||||
|
||||
For mitigation strategies, please refer to the FAQ entry *Can the output of a prompt vary across runs in vLLM?* in the `FAQs <../serving/faq.rst>`_.
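
As an illustration of the greedy sampling equality property discussed above, the following is a minimal sketch (not one of the official tests) that compares greedy outputs with and without speculative decoding. The model names and speculative settings are example values, and building two engines in one process assumes sufficient GPU memory:

.. code-block:: python

    from vllm import LLM, SamplingParams

    prompts = ["The future of AI is"]
    greedy = SamplingParams(temperature=0, max_tokens=32)

    # Baseline run without speculative decoding.
    baseline_llm = LLM(model="facebook/opt-6.7b")
    baseline = [o.outputs[0].text for o in baseline_llm.generate(prompts, greedy)]

    # Run with a small draft model proposing 5 speculative tokens per step.
    spec_llm = LLM(
        model="facebook/opt-6.7b",
        speculative_model="facebook/opt-125m",
        num_speculative_tokens=5,
        use_v2_block_manager=True,
    )
    speculative = [o.outputs[0].text for o in spec_llm.generate(prompts, greedy)]

    # Greedy outputs are expected to match, up to the caveats described above.
    assert baseline == speculative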
|
||||
|
||||
Resources for vLLM contributors
|
||||
-------------------------------
|
||||
|
@ -51,6 +51,10 @@ Decoder-only Language Models
|
||||
- DeciLM
|
||||
- :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc.
|
||||
-
|
||||
* - :code:`ExaoneForCausalLM`
|
||||
- EXAONE-3
|
||||
- :code:`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc.
|
||||
- ✅︎
|
||||
* - :code:`FalconForCausalLM`
|
||||
- Falcon
|
||||
- :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
|
||||
@ -143,6 +147,10 @@ Decoder-only Language Models
|
||||
- Phi-3-Small
|
||||
- :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
|
||||
-
|
||||
* - :code:`PhiMoEForCausalLM`
|
||||
- Phi-3.5-MoE
|
||||
- :code:`microsoft/Phi-3.5-MoE-instruct`, etc.
|
||||
-
|
||||
* - :code:`PersimmonForCausalLM`
|
||||
- Persimmon
|
||||
- :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc.
|
||||
@ -186,12 +194,12 @@ Multimodal Language Models
|
||||
|
||||
* - Architecture
|
||||
- Models
|
||||
- Supported Modalities
|
||||
- Modalities
|
||||
- Example HuggingFace Models
|
||||
- :ref:`LoRA <lora>`
|
||||
* - :code:`Blip2ForConditionalGeneration`
|
||||
- BLIP-2
|
||||
- Image
|
||||
- Image\ :sup:`E`
|
||||
- :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc.
|
||||
-
|
||||
* - :code:`ChameleonForConditionalGeneration`
|
||||
@ -206,44 +214,75 @@ Multimodal Language Models
|
||||
-
|
||||
* - :code:`InternVLChatModel`
|
||||
- InternVL2
|
||||
- Image
|
||||
- Image\ :sup:`E+`
|
||||
- :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
|
||||
-
|
||||
* - :code:`LlavaForConditionalGeneration`
|
||||
- LLaVA-1.5
|
||||
- Image
|
||||
- Image\ :sup:`E+`
|
||||
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
|
||||
-
|
||||
* - :code:`LlavaNextForConditionalGeneration`
|
||||
- LLaVA-NeXT
|
||||
- Image
|
||||
- Image\ :sup:`E+`
|
||||
- :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
|
||||
-
|
||||
* - :code:`LlavaNextVideoForConditionalGeneration`
|
||||
- LLaVA-NeXT-Video
|
||||
- Video
|
||||
- :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. (see note)
|
||||
-
|
||||
* - :code:`MiniCPMV`
|
||||
- MiniCPM-V
|
||||
- Image\ :sup:`+`
|
||||
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
|
||||
-
|
||||
* - :code:`PaliGemmaForConditionalGeneration`
|
||||
- PaliGemma
|
||||
- Image
|
||||
- Image\ :sup:`E`
|
||||
- :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc.
|
||||
-
|
||||
* - :code:`Phi3VForCausalLM`
|
||||
- Phi-3-Vision, Phi-3.5-Vision
|
||||
- Image
|
||||
- Image\ :sup:`E+`
|
||||
- :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
|
||||
-
|
||||
* - :code:`MiniCPMV`
|
||||
- MiniCPM-V
|
||||
- Image
|
||||
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
|
||||
* - :code:`PixtralForConditionalGeneration`
|
||||
- Pixtral
|
||||
- Image\ :sup:`+`
|
||||
- :code:`mistralai/Pixtral-12B-2409`
|
||||
-
|
||||
* - :code:`QWenLMHeadModel`
|
||||
- Qwen-VL
|
||||
- Image\ :sup:`E+`
|
||||
- :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
|
||||
-
|
||||
* - :code:`Qwen2VLForConditionalGeneration`
|
||||
- Qwen2-VL (see note)
|
||||
- Image\ :sup:`+` / Video\ :sup:`+`
|
||||
- :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
|
||||
-
|
||||
* - :code:`UltravoxModel`
|
||||
- Ultravox
|
||||
- Audio
|
||||
- Audio\ :sup:`E+`
|
||||
- :code:`fixie-ai/ultravox-v0_3`
|
||||
-
|
||||
|
||||
| :sup:`E` Pre-computed embeddings can be inputted for this modality.
|
||||
| :sup:`+` Multiple items can be inputted per text prompt for this modality.
|
||||
|
||||
.. note::
|
||||
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
|
||||
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
|
||||
|
||||
.. note::
|
||||
For :code:`LLaVA-NeXT-Video` and :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
|
||||
This can be installed by running the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install git+https://github.com/huggingface/transformers.git@21fac7abba2a37fae86106f87fcf9974fd1e3830
|
||||
|
||||
----
|
||||
|
||||
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
|
||||
|
@ -9,26 +9,23 @@ This document shows you how to run and serve these models using vLLM.
|
||||
.. important::
|
||||
We are actively iterating on VLM support. Expect breaking changes to VLM usage and development in upcoming releases without prior deprecation.
|
||||
|
||||
Currently, the support for vision language models on vLLM has the following limitations:
|
||||
|
||||
* Only single image input is supported per text prompt.
|
||||
|
||||
We are continuously improving user & developer experience for VLMs. Please `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_ if you have any feedback or feature requests.
|
||||
|
||||
Offline Batched Inference
|
||||
-------------------------
|
||||
Offline Inference
|
||||
-----------------
|
||||
|
||||
To initialize a VLM, the aforementioned arguments must be passed to the ``LLM`` class for instantiating the engine.
|
||||
Single-image input
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
The :class:`~vllm.LLM` class can be instantiated in much the same way as language-only models.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
llm = LLM(model="llava-hf/llava-1.5-7b-hf")
|
||||
|
||||
.. important::
|
||||
.. note::
|
||||
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
|
||||
the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that
|
||||
internally for each model.
|
||||
|
||||
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.
|
||||
|
||||
To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
|
||||
|
||||
@ -86,61 +83,117 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptI
|
||||
|
||||
A code example can be found in `examples/offline_inference_vision_language.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_vision_language.py>`_.
|
||||
|
||||
Multi-image input
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
Online OpenAI Vision API Compatible Inference
|
||||
----------------------------------------------
|
||||
Multi-image input is only supported for a subset of VLMs, as shown :ref:`here <supported_vlms>`.
|
||||
|
||||
To enable multiple multi-modal items per text prompt, you have to set ``limit_mm_per_prompt`` for the :class:`~vllm.LLM` class.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
llm = LLM(
|
||||
model="microsoft/Phi-3.5-vision-instruct",
|
||||
trust_remote_code=True, # Required to load Phi-3.5-vision
|
||||
max_model_len=4096, # Otherwise, it may not fit in smaller GPUs
|
||||
limit_mm_per_prompt={"image": 2}, # The maximum number to accept
|
||||
)
|
||||
|
||||
Instead of passing in a single image, you can pass in a list of images.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Refer to the HuggingFace repo for the correct format to use
|
||||
prompt = "<|user|>\n<image_1>\n<image_2>\nWhat is the content of each image?<|end|>\n<|assistant|>\n"
|
||||
|
||||
# Load the images using PIL.Image
|
||||
image1 = PIL.Image.open(...)
|
||||
image2 = PIL.Image.open(...)
|
||||
|
||||
outputs = llm.generate({
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"image": [image1, image2]
|
||||
},
|
||||
})
|
||||
|
||||
for o in outputs:
|
||||
generated_text = o.outputs[0].text
|
||||
print(generated_text)
|
||||
|
||||
A code example can be found in `examples/offline_inference_vision_language_multi_image.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_vision_language_multi_image.py>`_.
|
||||
|
||||
Online Inference
|
||||
----------------
|
||||
|
||||
OpenAI Vision API
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
You can serve vision language models with vLLM's HTTP server that is compatible with `OpenAI Vision API <https://platform.openai.com/docs/guides/vision>`_.
|
||||
|
||||
.. note::
|
||||
Currently, vLLM supports only **single** ``image_url`` input per ``messages``. Support for multi-image inputs will be
|
||||
added in the future.
|
||||
|
||||
Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with vLLM API server.
|
||||
|
||||
.. important::
|
||||
Since OpenAI Vision API is based on `Chat <https://platform.openai.com/docs/api-reference/chat>`_ API, a chat template
|
||||
is **required** to launch the API server if the model's tokenizer does not come with one. In this example, we use the
|
||||
HuggingFace Llava chat template that you can find in the example folder `here <https://github.com/vllm-project/vllm/blob/main/examples/template_llava.jinja>`_.
|
||||
Below is an example on how to launch the same ``microsoft/Phi-3.5-vision-instruct`` with vLLM's OpenAI-compatible API server.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja
|
||||
vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \
|
||||
--trust-remote-code --limit-mm-per-prompt image=2
|
||||
|
||||
.. important::
|
||||
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
|
||||
the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that
|
||||
internally for each model.
|
||||
Since OpenAI Vision API is based on `Chat Completions <https://platform.openai.com/docs/api-reference/chat>`_ API,
|
||||
a chat template is **required** to launch the API server.
|
||||
|
||||
Although Phi-3.5-Vision comes with a chat template, for other models you may have to provide one if the model's tokenizer does not come with it.
|
||||
The chat template can be inferred based on the documentation on the model's HuggingFace repo.
|
||||
For example, LLaVA-1.5 (``llava-hf/llava-1.5-7b-hf``) requires a chat template that can be found `here <https://github.com/vllm-project/vllm/blob/main/examples/template_llava.jinja>`_.
|
||||
|
||||
To consume the server, you can use the OpenAI client as in the example below:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
client = OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
# Single-image input inference
|
||||
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
|
||||
chat_response = client.chat.completions.create(
|
||||
model="llava-hf/llava-1.5-7b-hf",
|
||||
model="microsoft/Phi-3.5-vision-instruct",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": [
|
||||
# NOTE: The prompt formatting with the image token `<image>` is not needed
|
||||
# since the prompt will be processed automatically by the API server.
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "What’s in this image?"},
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
],
|
||||
}],
|
||||
)
|
||||
print("Chat response:", chat_response)
|
||||
print("Chat completion output:", chat_response.choices[0].message.content)
|
||||
|
||||
# Multi-image input inference
|
||||
image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
|
||||
image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"
|
||||
|
||||
chat_response = client.chat.completions.create(
|
||||
model="microsoft/Phi-3.5-vision-instruct",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What are the animals in these images?"},
|
||||
{"type": "image_url", "image_url": {"url": image_url_duck}},
|
||||
{"type": "image_url", "image_url": {"url": image_url_lion}},
|
||||
],
|
||||
}],
|
||||
)
|
||||
print("Chat completion output:", chat_response.choices[0].message.content)
|
||||
|
||||
|
||||
A full code example can be found in `examples/openai_vision_api_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_vision_api_client.py>`_.
|
||||
|
||||
|
@ -20,4 +20,4 @@ The performance benchmarks and nightly benchmarks can be triggered by submitting
|
||||
|
||||
.. note::
|
||||
|
||||
Please refer to `vLLM performance benchmark descriptions <https://github.com/vllm-project/vllm/blob/main/.buildkite/nightly-benchmarks/tests/descriptions.md>`_ and `vLLM nightly benchmark descriptions <https://github.com/vllm-project/vllm/blob/main/.buildkite/nightly-benchmarks/nightly-descriptions.md>`_ for detailed descriptions on benchmark environment, workload and metrics.
|
||||
Please refer to `vLLM performance benchmark descriptions <https://github.com/vllm-project/vllm/blob/main/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md>`_ and `vLLM nightly benchmark descriptions <https://github.com/vllm-project/vllm/blob/main/.buildkite/nightly-benchmarks/nightly-descriptions.md>`_ for detailed descriptions on benchmark environment, workload and metrics.
|
||||
|
@ -19,19 +19,21 @@ You can quantize your own models by installing AutoAWQ or picking one of the `40
|
||||
|
||||
$ pip install autoawq
|
||||
|
||||
After installing AutoAWQ, you are ready to quantize a model. Here is an example of how to quantize Vicuna 7B v1.5:
|
||||
After installing AutoAWQ, you are ready to quantize a model. Here is an example of how to quantize `mistralai/Mistral-7B-Instruct-v0.2`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from awq import AutoAWQForCausalLM
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
model_path = 'lmsys/vicuna-7b-v1.5'
|
||||
quant_path = 'vicuna-7b-v1.5-awq'
|
||||
model_path = 'mistralai/Mistral-7B-Instruct-v0.2'
|
||||
quant_path = 'mistral-instruct-v0.2-awq'
|
||||
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
|
||||
|
||||
# Load model
|
||||
model = AutoAWQForCausalLM.from_pretrained(model_path, **{"low_cpu_mem_usage": True})
|
||||
model = AutoAWQForCausalLM.from_pretrained(
|
||||
model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
# Quantize
|
||||
@ -41,6 +43,8 @@ After installing AutoAWQ, you are ready to quantize a model. Here is an example
|
||||
model.save_quantized(quant_path)
|
||||
tokenizer.save_pretrained(quant_path)
|
||||
|
||||
print(f'Model is quantized and saved at "{quant_path}"')
|
||||
|
||||
To run an AWQ model with vLLM, you can use `TheBloke/Llama-2-7b-Chat-AWQ <https://huggingface.co/TheBloke/Llama-2-7b-Chat-AWQ>`_ with the following command:
|
||||
|
||||
.. code-block:: console
|
||||
|
@ -119,17 +119,6 @@ The table below shows the compatibility of various quantization implementations
|
||||
- ✗
|
||||
- ✗
|
||||
- ✗
|
||||
* - SqueezeLLM
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
- ✗
|
||||
- ✗
|
||||
- ✗
|
||||
- ✗
|
||||
- ✗
|
||||
|
||||
Notes:
|
||||
^^^^^^
|
||||
|
@ -10,3 +10,22 @@ A: Assuming that you're referring to using OpenAI compatible server to serve mul
|
||||
Q: Which model to use for offline inference embedding?
|
||||
|
||||
A: If you want to use an embedding model, try: https://huggingface.co/intfloat/e5-mistral-7b-instruct. In contrast, models such as Llama-3-8b and Mistral-7B-Instruct-v0.3 are generation models rather than embedding models.
|
||||
|
||||
----------------------------------------
|
||||
|
||||
Q: Can the output of a prompt vary across runs in vLLM?
|
||||
|
||||
A: Yes, it can. vLLM does not guarantee stable log probabilities (logprobs) for the output tokens. Variations in logprobs may occur due to
|
||||
numerical instability in Torch operations or non-deterministic behavior in batched Torch operations when batching changes. For more details,
|
||||
see the `Numerical Accuracy section <https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations>`_.
|
||||
|
||||
In vLLM, the same requests might be batched differently due to factors such as other concurrent requests,
|
||||
changes in batch size, or batch expansion in speculative decoding. These batching variations, combined with numerical instability of Torch operations,
|
||||
can lead to slightly different logit/logprob values at each step. Such differences can accumulate, potentially resulting in
|
||||
different tokens being sampled. Once a different token is sampled, further divergence is likely.
|
||||
|
||||
**Mitigation Strategies**
|
||||
|
||||
- For improved stability and reduced variance, use `float32`. Note that this will require more memory.
|
||||
- If using `bfloat16`, switching to `float16` can also help.
|
||||
- Using request seeds can aid in achieving more stable generation for temperature > 0, but discrepancies due to precision differences may still occur. A short example of these settings is shown below.
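
The sketch below applies both mitigations through the offline ``LLM`` API; the model name is only an example:

.. code-block:: python

    from vllm import LLM, SamplingParams

    # Use float32 for improved numerical stability
    # (this roughly doubles memory usage compared to float16/bfloat16).
    llm = LLM(model="facebook/opt-125m", dtype="float32")

    # A per-request seed makes sampling with temperature > 0 more repeatable.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=42)

    outputs = llm.generate(["The capital of France is"], sampling_params)
    print(outputs[0].outputs[0].text)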
|
||||
|
@ -110,14 +110,90 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
|
||||
:func: create_parser_for_docs
|
||||
:prog: vllm serve
|
||||
```
|
||||
## Tool Calling in the Chat Completion API
|
||||
### Named Function Calling
|
||||
vLLM supports only named function calling in the chat completion API by default. It does so using Outlines, so it is
|
||||
always enabled and will work with any supported model. You are guaranteed a validly parsable function call - not necessarily a
|
||||
high-quality one.
|
||||
|
||||
To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and
|
||||
specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request.
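
For example, with a server already running on `localhost:8000`, the request below forces a call to a specific tool. This is only a sketch: the tool definition is illustrative, and the model is whatever the server is currently serving.

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
model = client.models.list().data[0].id  # use whatever model the server exposes

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",  # illustrative tool definition
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

completion = client.chat.completions.create(
    model=model,
    messages=[{"role": "user", "content": "What is the weather in Dallas?"}],
    tools=tools,
    # Named function calling: explicitly select the tool that must be called.
    tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
)
print(completion.choices[0].message.tool_calls)
```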
|
||||
|
||||
### Config file
|
||||
|
||||
The `serve` module can also accept arguments from a config file in
|
||||
`yaml` format. The arguments in the yaml must be specified using the
|
||||
long form of the argument outlined [here](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server):
|
||||
|
||||
For example:
|
||||
|
||||
```yaml
|
||||
# config.yaml
|
||||
|
||||
host: "127.0.0.1"
|
||||
port: 6379
|
||||
uvicorn-log-level: "info"
|
||||
```
|
||||
|
||||
```bash
|
||||
$ vllm serve SOME_MODEL --config config.yaml
|
||||
```
|
||||
---
|
||||
**NOTE**
|
||||
If an argument is supplied both via the command line and the config file, the value from the command line will take precedence.
|
||||
The order of priorities is `command line > config file values > defaults`.
|
||||
|
||||
---
|
||||
|
||||
## Tool calling in the chat completion API
|
||||
vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but are on the roadmap.
|
||||
|
||||
To use a named function, you need to define the function in the `tools` parameter and select it in the `tool_choice` parameter.
|
||||
|
||||
It is the caller's responsibility to prompt the model with the tool information; vLLM will not automatically manipulate the prompt. **This may change in the future.**
|
||||
It is the caller's responsibility to prompt the model with the tool information; vLLM will not automatically manipulate the prompt.
|
||||
|
||||
vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.
|
||||
|
||||
Please refer to the OpenAI API reference documentation for more information.
|
||||
|
||||
### Automatic Function Calling
|
||||
To enable this feature, you should set the following flags:
|
||||
* `--enable-auto-tool-choice` -- **mandatory** for auto tool choice. Tells vLLM that you want to enable the model to generate its own tool calls when it
|
||||
deems appropriate.
|
||||
* `--tool-call-parser` -- select the tool parser to use - currently either `hermes` or `mistral`. Additional tool parsers
|
||||
will continue to be added in the future.
|
||||
* `--chat-template` -- **optional** for auto tool choice. The path to the chat template which handles `tool`-role messages and `assistant`-role messages
|
||||
that contain previously generated tool calls. Hermes and Mistral models have tool-compatible chat templates in their
|
||||
`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat
|
||||
template configured in the `tokenizer_config.json`. In this case, it will be used per the `transformers` specification. More on this [here](https://huggingface.co/docs/transformers/en/chat_templating#why-do-some-models-have-multiple-templates)
|
||||
from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json)
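
With the server launched using these flags, a client only needs to pass `tools`; the model decides whether and when to call them. A minimal sketch (assuming a server on `localhost:8000`; the tool definition is illustrative):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
model = client.models.list().data[0].id

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",  # illustrative tool definition
        "description": "Get the current weather in a given city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

# tool_choice is omitted, so auto tool choice decides whether a tool call is
# emitted or a normal chat response is returned.
response = client.chat.completions.create(
    model=model,
    messages=[{"role": "user", "content": "What's the weather in Dallas?"}],
    tools=tools,
)
print(response.choices[0].message.tool_calls)
```

A fuller version of this flow, including streaming tool calls and sending tool results back to the model, is shown in `examples/openai_chat_completion_client_with_tools.py`.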
|
||||
|
||||
If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template!
|
||||
|
||||
#### Hermes Models
|
||||
All Nous Research Hermes-series models newer than Hermes 2 Pro should be supported.
|
||||
* `NousResearch/Hermes-2-Pro-*`
|
||||
* `NousResearch/Hermes-2-Theta-*`
|
||||
* `NousResearch/Hermes-3-*`
|
||||
|
||||
|
||||
_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge
|
||||
step in their creation_.
|
||||
|
||||
Flags: `--tool-call-parser hermes`
|
||||
|
||||
#### Mistral Models
|
||||
Supported models:
|
||||
* `mistralai/Mistral-7B-Instruct-v0.3` (confirmed)
|
||||
* Additional mistral function-calling models are compatible as well.
|
||||
|
||||
Known issues:
|
||||
1. Mistral 7B struggles to generate parallel tool calls correctly.
|
||||
2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is
|
||||
much shorter than what vLLM generates. Since an exception is thrown when this condition
|
||||
is not met, the following additional chat templates are provided:
|
||||
|
||||
* `examples/tool_chat_template_mistral.jinja` - this is the "official" Mistral chat template, but tweaked so that
|
||||
it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits)
|
||||
* `examples/tool_chat_template_mistral_parallel.jinja` - this is a "better" version that adds a tool-use system prompt
|
||||
when tools are provided, which results in much better reliability when working with parallel tool calling.
|
||||
|
||||
|
||||
Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
|
||||
|
@ -62,7 +62,7 @@ This script evaluates the inference throughput of language models using various
|
||||
|
||||
python3 benchmarks/benchmark_throughput.py --help
|
||||
usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL]
|
||||
[--tokenizer TOKENIZER] [--quantization {awq,gptq,squeezellm,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N]
|
||||
[--tokenizer TOKENIZER] [--quantization {awq,gptq,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N]
|
||||
[--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code]
|
||||
[--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}]
|
||||
[--quantization-param-path KV_CACHE_quantization_param_path]
|
||||
@ -76,7 +76,7 @@ optional arguments:
|
||||
--output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset.
|
||||
--model MODEL
|
||||
--tokenizer TOKENIZER
|
||||
--quantization {awq,gptq,squeezellm,None}, -q {awq,gptq,squeezellm,None}
|
||||
--quantization {awq,gptq,None}, -q {awq,gptq,None}
|
||||
--tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE
|
||||
--n N Number of generated sequences per prompt.
|
||||
--use-beam-search
|
||||
|
@ -1,6 +1,6 @@
|
||||
### Quantizer Utilities
|
||||
`quantize.py`: NVIDIA Quantization utilities using AMMO, ported from TensorRT-LLM:
|
||||
`https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py`
|
||||
`quantize.py`: NVIDIA Quantization utilities using TensorRT-Model-Optimizer, ported
|
||||
from TensorRT-LLM: [`examples/quantization/quantize.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py)
|
||||
|
||||
### Prerequisite
|
||||
|
||||
|
@ -11,25 +11,33 @@ from vllm import LLM, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
# Input audio and question
|
||||
audio_and_sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate
|
||||
question = "What is recited in the audio?"
|
||||
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
|
||||
question_per_audio_count = [
|
||||
"What is recited in the audio?",
|
||||
"What sport and what nursery rhyme are referenced?"
|
||||
]
|
||||
|
||||
|
||||
# Ultravox 0.3
|
||||
def run_ultravox(question):
|
||||
def run_ultravox(question, audio_count):
|
||||
model_name = "fixie-ai/ultravox-v0_3"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
messages = [{
|
||||
'role': 'user',
|
||||
'content': f"<|reserved_special_token_0|>\n{question}"
|
||||
'role':
|
||||
'user',
|
||||
'content':
|
||||
"<|reserved_special_token_0|>\n" * audio_count + question
|
||||
}]
|
||||
prompt = tokenizer.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
|
||||
llm = LLM(model=model_name)
|
||||
llm = LLM(model=model_name,
|
||||
enforce_eager=True,
|
||||
enable_chunked_prefill=False,
|
||||
max_model_len=8192,
|
||||
limit_mm_per_prompt={"audio": audio_count})
|
||||
stop_token_ids = None
|
||||
return llm, prompt, stop_token_ids
|
||||
|
||||
@ -44,7 +52,9 @@ def main(args):
|
||||
if model not in model_example_map:
|
||||
raise ValueError(f"Model type {model} is not supported.")
|
||||
|
||||
llm, prompt, stop_token_ids = model_example_map[model](question)
|
||||
audio_count = args.num_audios
|
||||
llm, prompt, stop_token_ids = model_example_map[model](
|
||||
question_per_audio_count[audio_count - 1], audio_count)
|
||||
|
||||
# We set temperature to 0.2 so that outputs can be different
|
||||
# even when all prompts are identical when running batch inference.
|
||||
@ -53,23 +63,18 @@ def main(args):
|
||||
stop_token_ids=stop_token_ids)
|
||||
|
||||
assert args.num_prompts > 0
|
||||
if args.num_prompts == 1:
|
||||
# Single inference
|
||||
inputs = {
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"audio": audio_and_sample_rate
|
||||
"audio": [
|
||||
asset.audio_and_sample_rate
|
||||
for asset in audio_assets[:audio_count]
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
else:
|
||||
if args.num_prompts > 1:
|
||||
# Batch inference
|
||||
inputs = [{
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"audio": audio_and_sample_rate
|
||||
},
|
||||
} for _ in range(args.num_prompts)]
|
||||
inputs = [inputs] * args.num_prompts
|
||||
|
||||
outputs = llm.generate(inputs, sampling_params=sampling_params)
|
||||
|
||||
@ -92,6 +97,11 @@ if __name__ == "__main__":
|
||||
type=int,
|
||||
default=1,
|
||||
help='Number of prompts to run.')
|
||||
parser.add_argument("--num-audios",
|
||||
type=int,
|
||||
default=1,
|
||||
choices=[1, 2],
|
||||
help="Number of audio items per prompt.")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -1,5 +1,12 @@
|
||||
import os
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# creates XLA hlo graphs for all the context length buckets.
|
||||
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
|
||||
# creates XLA hlo graphs for all the token gen buckets.
|
||||
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
@ -19,8 +26,8 @@ llm = LLM(
|
||||
# Currently, this is a known limitation in continuous batching support
|
||||
# in transformers-neuronx.
|
||||
# TODO(liangfu): Support paged-attention in transformers-neuronx.
|
||||
max_model_len=128,
|
||||
block_size=128,
|
||||
max_model_len=2048,
|
||||
block_size=2048,
|
||||
# The device can be automatically detected when AWS Neuron SDK is installed.
|
||||
# The device argument can be either unspecified for automated detection,
|
||||
# or explicitly assigned.
|
||||
|
50
examples/offline_inference_neuron_int8_quantization.py
Normal file
@ -0,0 +1,50 @@
|
||||
import os
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# creates XLA hlo graphs for all the context length buckets.
|
||||
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
|
||||
# creates XLA hlo graphs for all the token gen buckets.
|
||||
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
|
||||
# Quantizes the neuron model weights to int8.
|
||||
# The default config for quantization is int8 dtype.
|
||||
os.environ['NEURON_QUANT_DTYPE'] = "s8"
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(
|
||||
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
max_num_seqs=8,
|
||||
# The max_model_len and block_size arguments are required to be same as
|
||||
# max sequence length when targeting neuron device.
|
||||
# Currently, this is a known limitation in continuous batching support
|
||||
# in transformers-neuronx.
|
||||
# TODO(liangfu): Support paged-attention in transformers-neuronx.
|
||||
max_model_len=2048,
|
||||
block_size=2048,
|
||||
# The device can be automatically detected when AWS Neuron SDK is installed.
|
||||
# The device argument can be either unspecified for automated detection,
|
||||
# or explicitly assigned.
|
||||
device="neuron",
|
||||
quantization="neuron_quant",
|
||||
override_neuron_config={
|
||||
"cast_logits_dtype": "bfloat16",
|
||||
},
|
||||
tensor_parallel_size=2)
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
165
examples/offline_inference_pixtral.py
Normal file
@ -0,0 +1,165 @@
|
||||
# ruff: noqa
|
||||
import argparse
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
# This script is an offline demo for running Pixtral.
|
||||
#
|
||||
# If you want to run a server/client setup, please follow this code:
|
||||
#
|
||||
# - Server:
|
||||
#
|
||||
# ```bash
|
||||
# vllm serve mistralai/Pixtral-12B-2409 --tokenizer-mode mistral --limit-mm-per-prompt 'image=4' --max-model-len 16384
|
||||
# ```
|
||||
#
|
||||
# - Client:
|
||||
#
|
||||
# ```bash
|
||||
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
|
||||
# --header 'Content-Type: application/json' \
|
||||
# --header 'Authorization: Bearer token' \
|
||||
# --data '{
|
||||
# "model": "mistralai/Pixtral-12B-2409",
|
||||
# "messages": [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {"type" : "text", "text": "Describe this image in detail please."},
|
||||
# {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
|
||||
# {"type" : "text", "text": "and this one as well. Answer in French."},
|
||||
# {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
|
||||
# ]
|
||||
# }
|
||||
# ]
|
||||
# }'
|
||||
# ```
|
||||
#
|
||||
# Usage:
|
||||
# python demo.py simple
|
||||
# python demo.py advanced
|
||||
|
||||
|
||||
def run_simple_demo():
|
||||
model_name = "mistralai/Pixtral-12B-2409"
|
||||
sampling_params = SamplingParams(max_tokens=8192)
|
||||
|
||||
# Lower max_num_seqs or max_model_len on low-VRAM GPUs.
|
||||
llm = LLM(model=model_name, tokenizer_mode="mistral")
|
||||
|
||||
prompt = "Describe this image in one sentence."
|
||||
image_url = "https://picsum.photos/id/237/200/300"
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
outputs = llm.chat(messages, sampling_params=sampling_params)
|
||||
|
||||
print(outputs[0].outputs[0].text)
|
||||
|
||||
|
||||
def run_advanced_demo():
|
||||
model_name = "mistralai/Pixtral-12B-2409"
|
||||
max_img_per_msg = 5
|
||||
max_tokens_per_img = 4096
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
tokenizer_mode="mistral",
|
||||
limit_mm_per_prompt={"image": max_img_per_msg},
|
||||
max_model_len=max_img_per_msg * max_tokens_per_img,
|
||||
)
|
||||
|
||||
prompt = "Describe the following image."
|
||||
|
||||
url_1 = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png"
|
||||
url_2 = "https://picsum.photos/seed/picsum/200/300"
|
||||
url_3 = "https://picsum.photos/id/32/512/512"
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": url_1
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": url_2
|
||||
}
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The images show nature.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "More details please and answer only in French!.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": url_3
|
||||
}
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
outputs = llm.chat(messages=messages, sampling_params=sampling_params)
|
||||
print(outputs[0].outputs[0].text)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run a demo in simple or advanced mode.")
|
||||
|
||||
parser.add_argument(
|
||||
"mode",
|
||||
choices=["simple", "advanced"],
|
||||
help="Specify the demo mode: 'simple' or 'advanced'",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.mode == "simple":
|
||||
print("Running simple demo...")
|
||||
run_simple_demo()
|
||||
elif args.mode == "advanced":
|
||||
print("Running advanced demo...")
|
||||
run_advanced_demo()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -9,12 +9,9 @@ from transformers import AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.assets.video import VideoAsset
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
# Input image and question
|
||||
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
|
||||
question = "What is the content of this image?"
|
||||
|
||||
|
||||
# LLaVA-1.5
|
||||
def run_llava(question):
|
||||
@ -30,7 +27,16 @@ def run_llava(question):
|
||||
def run_llava_next(question):
|
||||
|
||||
prompt = f"[INST] <image>\n{question} [/INST]"
|
||||
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf")
|
||||
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
|
||||
stop_token_ids = None
|
||||
return llm, prompt, stop_token_ids
|
||||
|
||||
|
||||
# LlaVA-NeXT-Video
|
||||
# Currently only support for video input
|
||||
def run_llava_next_video(question):
|
||||
prompt = f"USER: <video>\n{question} ASSISTANT:"
|
||||
llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192)
|
||||
stop_token_ids = None
|
||||
return llm, prompt, stop_token_ids
|
||||
|
||||
@ -159,9 +165,41 @@ def run_blip2(question):
|
||||
return llm, prompt, stop_token_ids
|
||||
|
||||
|
||||
# Qwen
|
||||
def run_qwen_vl(question):
|
||||
|
||||
llm = LLM(
|
||||
model="Qwen/Qwen-VL",
|
||||
trust_remote_code=True,
|
||||
max_num_seqs=5,
|
||||
)
|
||||
|
||||
prompt = f"{question}Picture 1: <img></img>\n"
|
||||
stop_token_ids = None
|
||||
return llm, prompt, stop_token_ids
|
||||
|
||||
|
||||
# Qwen2-VL
|
||||
def run_qwen2_vl(question):
|
||||
model_name = "Qwen/Qwen2-VL-7B-Instruct"
|
||||
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
max_num_seqs=5,
|
||||
)
|
||||
|
||||
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
|
||||
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n")
|
||||
stop_token_ids = None
|
||||
return llm, prompt, stop_token_ids
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"llava": run_llava,
|
||||
"llava-next": run_llava_next,
|
||||
"llava-next-video": run_llava_next_video,
|
||||
"fuyu": run_fuyu,
|
||||
"phi3_v": run_phi3v,
|
||||
"paligemma": run_paligemma,
|
||||
@ -169,14 +207,54 @@ model_example_map = {
|
||||
"minicpmv": run_minicpmv,
|
||||
"blip-2": run_blip2,
|
||||
"internvl_chat": run_internvl,
|
||||
"qwen_vl": run_qwen_vl,
|
||||
"qwen2_vl": run_qwen2_vl,
|
||||
}
|
||||
|
||||
|
||||
def get_multi_modal_input(args):
|
||||
"""
|
||||
return {
|
||||
"data": image or video,
|
||||
"question": question,
|
||||
}
|
||||
"""
|
||||
if args.modality == "image":
|
||||
# Input image and question
|
||||
image = ImageAsset("cherry_blossom") \
|
||||
.pil_image.convert("RGB")
|
||||
img_question = "What is the content of this image?"
|
||||
|
||||
return {
|
||||
"data": image,
|
||||
"question": img_question,
|
||||
}
|
||||
|
||||
if args.modality == "video":
|
||||
# Input video and question
|
||||
video = VideoAsset(name="sample_demo_1.mp4",
|
||||
num_frames=args.num_frames).np_ndarrays
|
||||
vid_question = "Why is this video funny?"
|
||||
|
||||
return {
|
||||
"data": video,
|
||||
"question": vid_question,
|
||||
}
|
||||
|
||||
msg = f"Modality {args.modality} is not supported."
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def main(args):
|
||||
model = args.model_type
|
||||
if model not in model_example_map:
|
||||
raise ValueError(f"Model type {model} is not supported.")
|
||||
|
||||
modality = args.modality
|
||||
mm_input = get_multi_modal_input(args)
|
||||
data = mm_input["data"]
|
||||
question = mm_input["question"]
|
||||
|
||||
llm, prompt, stop_token_ids = model_example_map[model](question)
|
||||
|
||||
# We set temperature to 0.2 so that outputs can be different
|
||||
@ -191,7 +269,7 @@ def main(args):
|
||||
inputs = {
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"image": image
|
||||
modality: data
|
||||
},
|
||||
}
|
||||
|
||||
@ -200,7 +278,7 @@ def main(args):
|
||||
inputs = [{
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"image": image
|
||||
modality: data
|
||||
},
|
||||
} for _ in range(args.num_prompts)]
|
||||
|
||||
@ -223,8 +301,15 @@ if __name__ == "__main__":
|
||||
help='Huggingface "model_type".')
|
||||
parser.add_argument('--num-prompts',
|
||||
type=int,
|
||||
default=1,
|
||||
default=4,
|
||||
help='Number of prompts to run.')
|
||||
|
||||
parser.add_argument('--modality',
|
||||
type=str,
|
||||
default="image",
|
||||
help='Modality of the input.')
|
||||
parser.add_argument('--num-frames',
|
||||
type=int,
|
||||
default=16,
|
||||
help='Number of frames to extract from the video.')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
243
examples/offline_inference_vision_language_multi_image.py
Normal file
@ -0,0 +1,243 @@
|
||||
"""
|
||||
This example shows how to use vLLM for running offline inference with
|
||||
multi-image input on vision language models, using the chat template defined
|
||||
by the model.
|
||||
"""
|
||||
from argparse import Namespace
|
||||
from typing import List
|
||||
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
QUESTION = "What is the content of each image?"
|
||||
IMAGE_URLS = [
|
||||
"https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
|
||||
"https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
|
||||
]
|
||||
|
||||
|
||||
def load_qwenvl_chat(question: str, image_urls: List[str]):
|
||||
model_name = "Qwen/Qwen-VL-Chat"
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
placeholders = "".join(f"Picture {i}: <img></img>\n"
|
||||
for i, _ in enumerate(image_urls, start=1))
|
||||
|
||||
# This model does not have a chat_template attribute on its tokenizer,
|
||||
# so we need to explicitly pass it. We use ChatML since it's used in the
|
||||
# generation utils of the model:
|
||||
# https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
trust_remote_code=True)
|
||||
|
||||
# Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
|
||||
chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # noqa: E501
|
||||
|
||||
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
|
||||
prompt = tokenizer.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
chat_template=chat_template)
|
||||
|
||||
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
|
||||
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
|
||||
return llm, prompt, stop_token_ids, None, chat_template
|
||||
|
||||
|
||||
def load_phi3v(question: str, image_urls: List[str]):
|
||||
llm = LLM(
|
||||
model="microsoft/Phi-3.5-vision-instruct",
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
placeholders = "\n".join(f"<|image_{i}|>"
|
||||
for i, _ in enumerate(image_urls, start=1))
|
||||
prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
|
||||
stop_token_ids = None
|
||||
return llm, prompt, stop_token_ids, None, None
|
||||
|
||||
|
||||
def load_internvl(question: str, image_urls: List[str]):
|
||||
model_name = "OpenGVLab/InternVL2-2B"
|
||||
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_num_seqs=5,
|
||||
max_model_len=4096,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
|
||||
placeholders = "\n".join(f"Image-{i}: <image>\n"
|
||||
for i, _ in enumerate(image_urls, start=1))
|
||||
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
trust_remote_code=True)
|
||||
prompt = tokenizer.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
|
||||
# Stop tokens for InternVL
|
||||
# models variants may have different stop tokens
|
||||
# please refer to the model card for the correct "stop words":
|
||||
# https://huggingface.co/OpenGVLab/InternVL2-2B#service
|
||||
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
|
||||
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
|
||||
|
||||
return llm, prompt, stop_token_ids, None, None
|
||||
|
||||
|
||||
def load_qwen2_vl(question, image_urls: List[str]):
|
||||
try:
|
||||
from qwen_vl_utils import process_vision_info
|
||||
except ModuleNotFoundError:
|
||||
print('WARNING: `qwen-vl-utils` not installed, input images will not '
|
||||
'be automatically resized. You can enable this functionality by '
|
||||
'`pip install qwen-vl-utils`.')
|
||||
process_vision_info = None
|
||||
|
||||
model_name = "Qwen/Qwen2-VL-7B-Instruct"
|
||||
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
max_num_seqs=5,
|
||||
max_model_len=32768 if process_vision_info is None else 4096,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
|
||||
placeholders = [{"type": "image", "image": url} for url in image_urls]
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
*placeholders,
|
||||
{
|
||||
"type": "text",
|
||||
"text": question
|
||||
},
|
||||
],
|
||||
}]
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_name)
|
||||
|
||||
prompt = processor.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
|
||||
stop_token_ids = None
|
||||
|
||||
if process_vision_info is None:
|
||||
image_data = [fetch_image(url) for url in image_urls]
|
||||
else:
|
||||
image_data, _ = process_vision_info(messages)
|
||||
|
||||
return llm, prompt, stop_token_ids, image_data, None
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"phi3_v": load_phi3v,
|
||||
"internvl_chat": load_internvl,
|
||||
"qwen2_vl": load_qwen2_vl,
|
||||
"qwen_vl_chat": load_qwenvl_chat,
|
||||
}
|
||||
|
||||
|
||||
def run_generate(model, question: str, image_urls: List[str]):
|
||||
llm, prompt, stop_token_ids, image_data, _ = model_example_map[model](
|
||||
question, image_urls)
|
||||
if image_data is None:
|
||||
image_data = [fetch_image(url) for url in image_urls]
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.0,
|
||||
max_tokens=128,
|
||||
stop_token_ids=stop_token_ids)
|
||||
|
||||
outputs = llm.generate(
|
||||
{
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"image": image_data
|
||||
},
|
||||
},
|
||||
sampling_params=sampling_params)
|
||||
|
||||
for o in outputs:
|
||||
generated_text = o.outputs[0].text
|
||||
print(generated_text)
|
||||
|
||||
|
||||
def run_chat(model: str, question: str, image_urls: List[str]):
|
||||
llm, _, stop_token_ids, _, chat_template = model_example_map[model](
|
||||
question, image_urls)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.0,
|
||||
max_tokens=128,
|
||||
stop_token_ids=stop_token_ids)
|
||||
outputs = llm.chat(
|
||||
[{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": question,
|
||||
},
|
||||
*({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
},
|
||||
} for image_url in image_urls),
|
||||
],
|
||||
}],
|
||||
sampling_params=sampling_params,
|
||||
chat_template=chat_template,
|
||||
)
|
||||
|
||||
for o in outputs:
|
||||
generated_text = o.outputs[0].text
|
||||
print(generated_text)
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
model = args.model_type
|
||||
method = args.method
|
||||
|
||||
if method == "generate":
|
||||
run_generate(model, QUESTION, IMAGE_URLS)
|
||||
elif method == "chat":
|
||||
run_chat(model, QUESTION, IMAGE_URLS)
|
||||
else:
|
||||
raise ValueError(f"Invalid method: {method}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using vLLM for offline inference with '
|
||||
'vision language models that support multi-image input')
|
||||
parser.add_argument('--model-type',
|
||||
'-m',
|
||||
type=str,
|
||||
default="phi3_v",
|
||||
choices=model_example_map.keys(),
|
||||
help='Huggingface "model_type".')
|
||||
parser.add_argument("--method",
|
||||
type=str,
|
||||
default="generate",
|
||||
choices=["generate", "chat"],
|
||||
help="The method to run in `vllm.LLM`.")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
33
examples/offline_inference_with_profiler.py
Normal file
@ -0,0 +1,33 @@
|
||||
import os
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# enable torch profiler, can also be set on cmd line
|
||||
os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_profile"
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
|
||||
|
||||
llm.start_profile()
|
||||
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
llm.stop_profile()
|
||||
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
162
examples/openai_chat_completion_client_with_tools.py
Normal file
@ -0,0 +1,162 @@
"""
Set up this example by starting a vLLM OpenAI-compatible server with tool call
options enabled. For example:

IMPORTANT: for mistral, you must use one of the provided mistral tool call
templates, or your own - the model default doesn't work for tool calls with vLLM.
See the vLLM docs on OpenAI server & tool calling for more details.

vllm serve --model mistralai/Mistral-7B-Instruct-v0.3 \
            --chat-template examples/tool_chat_template_mistral.jinja \
            --enable-auto-tool-choice --tool-call-parser mistral

OR
vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \
            --chat-template examples/tool_chat_template_hermes.jinja \
            --enable-auto-tool-choice --tool-call-parser hermes
"""
import json

from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type":
                    "string",
                    "description":
                    "The city to find the weather for, e.g. 'San Francisco'"
                },
                "state": {
                    "type":
                    "string",
                    "description":
                    "the two-letter abbreviation for the state that the city is"
                    " in, e.g. 'CA' which would mean 'California'"
                },
                "unit": {
                    "type": "string",
                    "description": "The unit to fetch the temperature in",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["city", "state", "unit"]
        }
    }
}]

messages = [{
    "role": "user",
    "content": "Hi! How are you doing today?"
}, {
    "role": "assistant",
    "content": "I'm doing well! How can I help you?"
}, {
    "role":
    "user",
    "content":
    "Can you tell me what the temperature will be in Dallas, in fahrenheit?"
}]

chat_completion = client.chat.completions.create(messages=messages,
                                                 model=model,
                                                 tools=tools)

print("Chat completion results:")
print(chat_completion)
print("\n\n")

tool_calls_stream = client.chat.completions.create(messages=messages,
                                                   model=model,
                                                   tools=tools,
                                                   stream=True)

chunks = []
for chunk in tool_calls_stream:
    chunks.append(chunk)
    if chunk.choices[0].delta.tool_calls:
        print(chunk.choices[0].delta.tool_calls[0])
    else:
        print(chunk.choices[0].delta)

arguments = []
tool_call_idx = -1
for chunk in chunks:

    if chunk.choices[0].delta.tool_calls:
        tool_call = chunk.choices[0].delta.tool_calls[0]

        if tool_call.index != tool_call_idx:
            if tool_call_idx >= 0:
                print(
                    f"streamed tool call arguments: {arguments[tool_call_idx]}"
                )
            tool_call_idx = chunk.choices[0].delta.tool_calls[0].index
            arguments.append("")
        if tool_call.id:
            print(f"streamed tool call id: {tool_call.id} ")

        if tool_call.function:
            if tool_call.function.name:
                print(f"streamed tool call name: {tool_call.function.name}")

            if tool_call.function.arguments:
                arguments[tool_call_idx] += tool_call.function.arguments

if len(arguments):
    print(f"streamed tool call arguments: {arguments[-1]}")

print("\n\n")

messages.append({
    "role": "assistant",
    "tool_calls": chat_completion.choices[0].message.tool_calls
})


# Now, simulate a tool call
def get_current_weather(city: str, state: str, unit: 'str'):
    return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
            "partly cloudy, with highs in the 90's.")


available_tools = {"get_current_weather": get_current_weather}

completion_tool_calls = chat_completion.choices[0].message.tool_calls
for call in completion_tool_calls:
    tool_to_call = available_tools[call.function.name]
    args = json.loads(call.function.arguments)
    result = tool_to_call(**args)
    print(result)
    messages.append({
        "role": "tool",
        "content": result,
        "tool_call_id": call.id,
        "name": call.function.name
    })

chat_completion_2 = client.chat.completions.create(messages=messages,
                                                   model=model,
                                                   tools=tools,
                                                   stream=False)
print("\n\n")
print(chat_completion_2)
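For context on the streaming section above: each chunk's function.arguments delta is a plain string fragment, so a client accumulates the fragments per tool-call index and parses them only after the stream ends. A minimal sketch of that final parsing step, using a stand-in value for the accumulated arguments list (the JSON shown is illustrative, not an actual server response):

import json

# Stand-in for the `arguments` list built in the streaming loop above:
# one accumulated JSON string per tool call.
arguments = ['{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}']

parsed_calls = []
for raw in arguments:
    try:
        parsed_calls.append(json.loads(raw))
    except json.JSONDecodeError:
        # Keep malformed or truncated fragments around for debugging.
        parsed_calls.append({"_raw": raw})

print(parsed_calls)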
@@ -19,7 +19,6 @@ responses = client.embeddings.create(
        "The best thing about vLLM is that it supports many different models"
    ],
    model=model,
    encoding_format="float",
)

for data in responses.data:
@@ -1,7 +1,13 @@
"""An example showing how to use vLLM to serve VLMs.

Launch the vLLM server with the following command:

(single image inference with Llava)
vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja

(multi-image inference with Phi-3.5-vision-instruct)
vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \
    --trust-remote-code --limit-mm-per-prompt image=2
"""
import base64

@@ -21,9 +27,10 @@ client = OpenAI(
models = client.models.list()
model = models.data[0].id

# Single-image input inference
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"

# Use image url in the payload
## Use image url in the payload
chat_completion_from_url = client.chat.completions.create(
    messages=[{
        "role":
@@ -46,10 +53,10 @@ chat_completion_from_url = client.chat.completions.create(
)

result = chat_completion_from_url.choices[0].message.content
print(f"Chat completion output:{result}")
print("Chat completion output:", result)


# Use base64 encoded image in the payload
## Use base64 encoded image in the payload
def encode_image_base64_from_url(image_url: str) -> str:
    """Encode an image retrieved from a remote url to base64 format."""

@@ -84,3 +91,36 @@ chat_completion_from_base64 = client.chat.completions.create(

result = chat_completion_from_base64.choices[0].message.content
print(f"Chat completion output:{result}")

# Multi-image input inference
image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"
chat_completion_from_url = client.chat.completions.create(
    messages=[{
        "role":
        "user",
        "content": [
            {
                "type": "text",
                "text": "What are the animals in these images?"
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": image_url_duck
                },
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": image_url_lion
                },
            },
        ],
    }],
    model=model,
    max_tokens=64,
)

result = chat_completion_from_url.choices[0].message.content
print("Chat completion output:", result)
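For context on the base64 path above: the same "data:image/jpeg;base64,..." URL form used for remote images also works for images read from disk. A minimal sketch under that assumption; "example.jpg" is a hypothetical local path, and the message layout mirrors the multi-image request above:

import base64


def encode_local_image_base64(path: str) -> str:
    """Encode a local image file to base64 for an image_url payload."""
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


# Hypothetical local file; substitute any JPEG on disk.
image_b64 = encode_local_image_base64("example.jpg")
message_content = [
    {
        "type": "text",
        "text": "What's in this image?"
    },
    {
        "type": "image_url",
        "image_url": {
            "url": f"data:image/jpeg;base64,{image_b64}"
        },
    },
]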
examples/tool_chat_template_hermes.jinja (new file, 130 lines)
@@ -0,0 +1,130 @@
|
||||
{%- macro json_to_python_type(json_spec) %}
|
||||
{%- set basic_type_map = {
|
||||
"string": "str",
|
||||
"number": "float",
|
||||
"integer": "int",
|
||||
"boolean": "bool"
|
||||
} %}
|
||||
|
||||
{%- if basic_type_map[json_spec.type] is defined %}
|
||||
{{- basic_type_map[json_spec.type] }}
|
||||
{%- elif json_spec.type == "array" %}
|
||||
{{- "list[" + json_to_python_type(json_spec|items) + "]" }}
|
||||
{%- elif json_spec.type == "object" %}
|
||||
{%- if json_spec.additionalProperties is defined %}
|
||||
{{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']' }}
|
||||
{%- else %}
|
||||
{{- "dict" }}
|
||||
{%- endif %}
|
||||
{%- elif json_spec.type is iterable %}
|
||||
{{- "Union[" }}
|
||||
{%- for t in json_spec.type %}
|
||||
{{- json_to_python_type({"type": t}) }}
|
||||
{%- if not loop.last %}
|
||||
{{- "," }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- "]" }}
|
||||
{%- else %}
|
||||
{{- "Any" }}
|
||||
{%- endif %}
|
||||
{%- endmacro %}
|
||||
|
||||
|
||||
{{- bos_token }}
|
||||
{{- "<|im_start|>system\nYou are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> " }}
|
||||
{%- if tools is iterable and tools | length > 0 %}
|
||||
{%- for tool in tools %}
|
||||
{%- if tool.function is defined %}
|
||||
{%- set tool = tool.function %}
|
||||
{%- endif %}
|
||||
{{- '{"type": "function", "function": ' }}
|
||||
{{- '{"name": "' + tool.name + '", ' }}
|
||||
{{- '"description": "' + tool.name + '(' }}
|
||||
{%- for param_name, param_fields in tool.parameters.properties|items %}
|
||||
{{- param_name + ": " + json_to_python_type(param_fields) }}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- ")" }}
|
||||
{%- if tool.return is defined %}
|
||||
{{- " -> " + json_to_python_type(tool.return) }}
|
||||
{%- endif %}
|
||||
{{- " - " + tool.description + "\n\n" }}
|
||||
{%- for param_name, param_fields in tool.parameters.properties|items %}
|
||||
{%- if loop.first %}
|
||||
{{- " Args:\n" }}
|
||||
{%- endif %}
|
||||
{{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }}
|
||||
{%- endfor %}
|
||||
{%- if tool.return is defined and tool.return.description is defined %}
|
||||
{{- "\n Returns:\n " + tool.return.description }}
|
||||
{%- endif %}
|
||||
{{- '"' }}
|
||||
{{- ', "parameters": ' }}
|
||||
{%- if tool.parameters.properties | length == 0 %}
|
||||
{{- "{}" }}
|
||||
{%- else %}
|
||||
{{- tool.parameters|tojson }}
|
||||
{%- endif %}
|
||||
{{- "}" }}
|
||||
{%- if not loop.last %}
|
||||
{{- "\n" }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- " </tools>" }}
|
||||
{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}}
|
||||
' }}
|
||||
{{- "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
||||
" }}
|
||||
{{- "<tool_call>
|
||||
" }}
|
||||
{{- '{"name": <function-name>, "arguments": <args-dict>}
|
||||
' }}
|
||||
{{- '</tool_call><|im_end|>' }}
|
||||
{%- for message in messages %}
|
||||
{%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" and message.tool_calls is defined %}
|
||||
{{- '<|im_start|>' + message.role }}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{{- '\n<tool_call>\n' }}
|
||||
{%- if tool_call.function is defined %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '{' }}
|
||||
{{- '"name": "' }}
|
||||
{{- tool_call.name }}
|
||||
{{- '"' }}
|
||||
{%- if tool_call.arguments is defined %}
|
||||
{{- ', ' }}
|
||||
{{- '"arguments": ' }}
|
||||
{{- tool_call.arguments|tojson }}
|
||||
{%- endif %}
|
||||
{{- '}' }}
|
||||
{{- '\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if loop.previtem and loop.previtem.role != "tool" %}
|
||||
{{- '<|im_start|>tool\n' }}
|
||||
{%- endif %}
|
||||
{{- '<tool_response>\n' }}
|
||||
{{- message.content }}
|
||||
{%- if not loop.last %}
|
||||
{{- '\n</tool_response>\n' }}
|
||||
{%- else %}
|
||||
{{- '\n</tool_response>' }}
|
||||
{%- endif %}
|
||||
{%- if not loop.last and loop.nextitem.role != "tool" %}
|
||||
{{- '<|im_end|>' }}
|
||||
{%- elif loop.last %}
|
||||
{{- '<|im_end|>' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- endif %}
|
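For context, a chat template such as the Hermes one above is normally rendered by the tokenizer rather than by hand. A minimal sketch of rendering it offline with Hugging Face transformers, assuming a recent transformers version whose apply_chat_template accepts chat_template= and tools= arguments; the tool schema and message are only illustrative:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Hermes-2-Pro-Llama-3-8B")

# Assumes this is run from the repository root, where the template lives.
with open("examples/tool_chat_template_hermes.jinja") as f:
    chat_template = f.read()

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The city to find the weather for"
                },
                "state": {
                    "type": "string",
                    "description": "Two-letter state abbreviation"
                },
                "unit": {
                    "type": "string",
                    "description": "Temperature unit",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["city", "state", "unit"]
        }
    }
}]

messages = [{"role": "user", "content": "What's the weather in Dallas?"}]

# Render the prompt as a string to inspect how tools and messages are laid out.
prompt = tokenizer.apply_chat_template(messages,
                                       tools=tools,
                                       chat_template=chat_template,
                                       tokenize=False,
                                       add_generation_prompt=True)
print(prompt)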
examples/tool_chat_template_mistral.jinja (new file, 86 lines)
@@ -0,0 +1,86 @@
|
||||
{%- if messages[0]["role"] == "system" %}
|
||||
{%- set system_message = messages[0]["content"] %}
|
||||
{%- set loop_messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- set loop_messages = messages %}
|
||||
{%- endif %}
|
||||
{%- if not tools is defined %}
|
||||
{%- set tools = none %}
|
||||
{%- endif %}
|
||||
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
|
||||
|
||||
{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %}
|
||||
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
|
||||
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{{- bos_token }}
|
||||
{%- for message in loop_messages %}
|
||||
{%- if message["role"] == "user" %}
|
||||
{%- if tools is not none and (message == user_messages[-1]) %}
|
||||
{{- "[AVAILABLE_TOOLS] [" }}
|
||||
{%- for tool in tools %}
|
||||
{%- set tool = tool.function %}
|
||||
{{- '{"type": "function", "function": {' }}
|
||||
{%- for key, val in tool.items() if key != "return" %}
|
||||
{%- if val is string %}
|
||||
{{- '"' + key + '": "' + val + '"' }}
|
||||
{%- else %}
|
||||
{{- '"' + key + '": ' + val|tojson }}
|
||||
{%- endif %}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- "}}" }}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- else %}
|
||||
{{- "]" }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- "[/AVAILABLE_TOOLS]" }}
|
||||
{%- endif %}
|
||||
{%- if loop.last and system_message is defined %}
|
||||
{{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }}
|
||||
{%- else %}
|
||||
{{- "[INST] " + message["content"] + "[/INST]" }}
|
||||
{%- endif %}
|
||||
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
|
||||
{%- if message.tool_calls is defined %}
|
||||
{%- set tool_calls = message.tool_calls %}
|
||||
{%- else %}
|
||||
{%- set tool_calls = message.content %}
|
||||
{%- endif %}
|
||||
{{- "[TOOL_CALLS] [" }}
|
||||
{%- for tool_call in tool_calls %}
|
||||
{%- set out = tool_call.function|tojson %}
|
||||
{{- out[:-1] }}
|
||||
{%- if not tool_call.id is defined or tool_call.id|length < 9 %}
|
||||
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }}
|
||||
{%- endif %}
|
||||
{{- ', "id": "' + tool_call.id[-9:] + '"}' }}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- else %}
|
||||
{{- "]" + eos_token }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- elif message["role"] == "assistant" %}
|
||||
{{- " " + message["content"] + eos_token }}
|
||||
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
|
||||
{%- if message.content is defined and message.content.content is defined %}
|
||||
{%- set content = message.content.content %}
|
||||
{%- else %}
|
||||
{%- set content = message.content %}
|
||||
{%- endif %}
|
||||
{{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }}
|
||||
{%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %}
|
||||
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }}
|
||||
{%- endif %}
|
||||
{{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }}
|
||||
{%- else %}
|
||||
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
examples/tool_chat_template_mistral_parallel.jinja (new file, 94 lines)
@@ -0,0 +1,94 @@
|
||||
{%- if messages[0]["role"] == "system" %}
|
||||
{%- set system_message = messages[0]["content"] %}
|
||||
{%- set loop_messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- set loop_messages = messages %}
|
||||
{%- endif %}
|
||||
{%- if not tools is defined %}
|
||||
{%- set tools = none %}
|
||||
{%- endif %}
|
||||
{%- if tools is defined %}
|
||||
{%- set parallel_tool_prompt = "You are a helpful assistant that can call tools. If you call one or more tools, format them in a single JSON array or objects, where each object is a tool call, not as separate objects outside of an array or multiple arrays. Use the format [{\"name\": tool call name, \"arguments\": tool call arguments}, additional tool calls] if you call more than one tool. If you call tools, do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %}
|
||||
{%- if system_message is defined %}
|
||||
{%- set system_message = parallel_tool_prompt + "\n\n" + system_message %}
|
||||
{%- else %}
|
||||
{%- set system_message = parallel_tool_prompt %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
|
||||
|
||||
{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %}
|
||||
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
|
||||
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{{- bos_token }}
|
||||
{%- for message in loop_messages %}
|
||||
{%- if message["role"] == "user" %}
|
||||
{%- if tools is not none and (message == user_messages[-1]) %}
|
||||
{{- "[AVAILABLE_TOOLS] [" }}
|
||||
{%- for tool in tools %}
|
||||
{%- set tool = tool.function %}
|
||||
{{- '{"type": "function", "function": {' }}
|
||||
{%- for key, val in tool.items() if key != "return" %}
|
||||
{%- if val is string %}
|
||||
{{- '"' + key + '": "' + val + '"' }}
|
||||
{%- else %}
|
||||
{{- '"' + key + '": ' + val|tojson }}
|
||||
{%- endif %}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- "}}" }}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- else %}
|
||||
{{- "]" }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- "[/AVAILABLE_TOOLS]" }}
|
||||
{%- endif %}
|
||||
{%- if loop.last and system_message is defined %}
|
||||
{{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }}
|
||||
{%- else %}
|
||||
{{- "[INST] " + message["content"] + "[/INST]" }}
|
||||
{%- endif %}
|
||||
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
|
||||
{%- if message.tool_calls is defined %}
|
||||
{%- set tool_calls = message.tool_calls %}
|
||||
{%- else %}
|
||||
{%- set tool_calls = message.content %}
|
||||
{%- endif %}
|
||||
{{- "[TOOL_CALLS] [" }}
|
||||
{%- for tool_call in tool_calls %}
|
||||
{%- set out = tool_call.function|tojson %}
|
||||
{{- out[:-1] }}
|
||||
{%- if not tool_call.id is defined or tool_call.id|length < 9 %}
|
||||
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }}
|
||||
{%- endif %}
|
||||
{{- ', "id": "' + tool_call.id[-9:] + '"}' }}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- else %}
|
||||
{{- "]" + eos_token }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- elif message["role"] == "assistant" %}
|
||||
{{- " " + message["content"] + eos_token }}
|
||||
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
|
||||
{%- if message.content is defined and message.content.content is defined %}
|
||||
{%- set content = message.content.content %}
|
||||
{%- else %}
|
||||
{%- set content = message.content %}
|
||||
{%- endif %}
|
||||
{{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }}
|
||||
{%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %}
|
||||
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }}
|
||||
{%- endif %}
|
||||
{{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }}
|
||||
{%- else %}
|
||||
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
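For context, both Mistral templates above reject a tool call whose id is missing or shorter than 9 characters and then emit only the last 9 characters of that id. A minimal sketch of generating an id that satisfies this check when building tool-call messages by hand; the helper name is illustrative:

import secrets
import string


def make_tool_call_id(length: int = 9) -> str:
    """Random alphanumeric id satisfying the template's length >= 9 check."""
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))


print(make_tool_call_id())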
@ -99,7 +99,6 @@ echo 'vLLM mypy:'
|
||||
mypy --follow-imports skip # Note that this is less strict than CI
|
||||
mypy tests --follow-imports skip
|
||||
mypy vllm/attention --follow-imports skip
|
||||
mypy vllm/core --follow-imports skip
|
||||
mypy vllm/distributed --follow-imports skip
|
||||
mypy vllm/engine --follow-imports skip
|
||||
mypy vllm/executor --follow-imports skip
|
||||
|
@ -58,6 +58,7 @@ files = [
|
||||
"vllm/adapter_commons",
|
||||
"vllm/assets",
|
||||
"vllm/entrypoints",
|
||||
"vllm/core",
|
||||
"vllm/inputs",
|
||||
"vllm/logging",
|
||||
"vllm/multimodal",
|
||||
@ -75,7 +76,7 @@ exclude = [
|
||||
|
||||
[tool.codespell]
|
||||
ignore-words-list = "dout, te, indicies, subtile"
|
||||
skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
|
||||
skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
|
||||
|
||||
[tool.isort]
|
||||
use_parentheses = true
|
||||
|
@ -1,3 +0,0 @@
|
||||
# Dependencies for Ray accelerated DAG
|
||||
cupy-cuda12x
|
||||
ray >= 2.32
|
@ -7,11 +7,11 @@ py-cpuinfo
|
||||
transformers >= 4.43.2 # Required for Chameleon and Llama 3.1 hotfix.
|
||||
tokenizers >= 0.19.1 # Required for Llama 3.
|
||||
protobuf # Required by LlamaTokenizer.
|
||||
fastapi
|
||||
fastapi >= 0.114.1
|
||||
aiohttp
|
||||
openai >= 1.0 # Ensure modern openai package (ensure types module present)
|
||||
openai >= 1.40.0 # Ensure modern openai package (ensure types module present)
|
||||
uvicorn[standard]
|
||||
pydantic >= 2.8 # Required for OpenAI server.
|
||||
pydantic >= 2.9 # Required for fastapi >= 0.113.0
|
||||
pillow # Required for image processing
|
||||
prometheus_client >= 0.18.0
|
||||
prometheus-fastapi-instrumentator >= 7.0.0
|
||||
@ -20,9 +20,12 @@ lm-format-enforcer == 0.10.6
|
||||
outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
|
||||
typing_extensions >= 4.10
|
||||
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
|
||||
partial-json-parser # used for parsing partial JSON outputs
|
||||
pyzmq
|
||||
msgspec
|
||||
librosa # Required for audio processing
|
||||
soundfile # Required for audio processing
|
||||
gguf == 0.9.1
|
||||
importlib_metadata
|
||||
mistral_common >= 1.4.0
|
||||
pyyaml
|
||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||
einops # Required for Qwen2-VL.
|
||||
|
@ -1,3 +0,0 @@
|
||||
# Mamba dependencies
|
||||
mamba-ssm>=1.2.2
|
||||
causal-conv1d>=1.2.0
|
@ -8,3 +8,4 @@ botocore
|
||||
ray >= 2.10.0
|
||||
peft
|
||||
pytest-asyncio
|
||||
tensorizer>=2.9.0
|
@ -1,6 +1,3 @@
|
||||
# Needed for Ray accelerated DAG tests
|
||||
-r requirements-adag.txt
|
||||
|
||||
# testing
|
||||
pytest
|
||||
tensorizer>=2.9.0
|
||||
@ -11,12 +8,15 @@ pytest-shard
|
||||
|
||||
# testing utils
|
||||
awscli
|
||||
einops # required for MPT and qwen-vl
|
||||
einops # required for MPT, qwen-vl and Mamba
|
||||
httpx
|
||||
librosa # required for audio test
|
||||
opencv-python # required for video test
|
||||
peft
|
||||
requests
|
||||
ray
|
||||
ray[adag]>=2.35
|
||||
sentence-transformers # required for embedding
|
||||
soundfile # required for audio test
|
||||
compressed-tensors==0.4.0 # required for compressed-tensors
|
||||
timm # required for internvl test
|
||||
transformers_stream_generator # required for qwen-vl test
|
||||
|
@ -4,4 +4,4 @@
|
||||
# Dependencies for TPU
|
||||
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
|
||||
# You can install the dependencies in Dockerfile.tpu.
|
||||
ray
|
||||
ray[default]
|
||||
|
setup.py (8 changed lines)
@@ -170,14 +170,17 @@ class cmake_build_ext(build_ext):
|
||||
|
||||
if is_sccache_available():
|
||||
cmake_args += [
|
||||
'-DCMAKE_C_COMPILER_LAUNCHER=sccache',
|
||||
'-DCMAKE_CXX_COMPILER_LAUNCHER=sccache',
|
||||
'-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache',
|
||||
'-DCMAKE_C_COMPILER_LAUNCHER=sccache',
|
||||
'-DCMAKE_HIP_COMPILER_LAUNCHER=sccache',
|
||||
]
|
||||
elif is_ccache_available():
|
||||
cmake_args += [
|
||||
'-DCMAKE_C_COMPILER_LAUNCHER=ccache',
|
||||
'-DCMAKE_CXX_COMPILER_LAUNCHER=ccache',
|
||||
'-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache',
|
||||
'-DCMAKE_HIP_COMPILER_LAUNCHER=ccache',
|
||||
]
|
||||
|
||||
# Pass the python executable to cmake so it can find an exact
|
||||
@ -362,6 +365,7 @@ def get_vllm_version() -> str:
|
||||
version = find_version(get_path("vllm", "version.py"))
|
||||
|
||||
if _no_device():
|
||||
if envs.VLLM_TARGET_DEVICE == "empty":
|
||||
version += "+empty"
|
||||
elif _is_cuda():
|
||||
cuda_version = str(get_nvcc_cuda_version())
|
||||
@ -501,6 +505,8 @@ setup(
|
||||
ext_modules=ext_modules,
|
||||
extras_require={
|
||||
"tensorizer": ["tensorizer>=2.9.0"],
|
||||
"video": ["opencv-python"], # Required for video processing
|
||||
"audio": ["librosa", "soundfile"] # Required for audio processing
|
||||
},
|
||||
cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {},
|
||||
package_data=package_data,
|
||||
|
@ -1,4 +1,3 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
@ -26,8 +25,7 @@ def _query_server_long(prompt: str) -> dict:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_server(tokenizer_pool_size: int, engine_use_ray: bool,
|
||||
worker_use_ray: bool):
|
||||
def api_server(tokenizer_pool_size: int, worker_use_ray: bool):
|
||||
script_path = Path(__file__).parent.joinpath(
|
||||
"api_server_async_engine.py").absolute()
|
||||
commands = [
|
||||
@ -37,25 +35,17 @@ def api_server(tokenizer_pool_size: int, engine_use_ray: bool,
|
||||
str(tokenizer_pool_size)
|
||||
]
|
||||
|
||||
# Copy the environment variables and append `VLLM_ALLOW_ENGINE_USE_RAY=1`
|
||||
# to prevent `--engine-use-ray` raises an exception due to it deprecation
|
||||
env_vars = os.environ.copy()
|
||||
env_vars["VLLM_ALLOW_ENGINE_USE_RAY"] = "1"
|
||||
|
||||
if engine_use_ray:
|
||||
commands.append("--engine-use-ray")
|
||||
if worker_use_ray:
|
||||
commands.append("--worker-use-ray")
|
||||
uvicorn_process = subprocess.Popen(commands, env=env_vars)
|
||||
uvicorn_process = subprocess.Popen(commands)
|
||||
yield
|
||||
uvicorn_process.terminate()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
|
||||
@pytest.mark.parametrize("worker_use_ray", [False, True])
|
||||
@pytest.mark.parametrize("engine_use_ray", [False, True])
|
||||
def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool,
|
||||
engine_use_ray: bool):
|
||||
def test_api_server(api_server, tokenizer_pool_size: int,
|
||||
worker_use_ray: bool):
|
||||
"""
|
||||
Run the API server and test it.
|
||||
|
||||
|
@ -1,8 +1,10 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from asyncio import CancelledError
|
||||
from copy import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
@ -12,6 +14,7 @@ from vllm import SamplingParams
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
|
||||
from vllm.outputs import RequestOutput as RealRequestOutput
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
|
||||
from ..conftest import cleanup
|
||||
from ..utils import wait_for_gpu_memory_to_clear
|
||||
@ -72,14 +75,12 @@ class MockEngine:
|
||||
|
||||
|
||||
class MockAsyncLLMEngine(AsyncLLMEngine):
|
||||
|
||||
def _init_engine(self, *args, **kwargs):
|
||||
return MockEngine()
|
||||
_engine_class = MockEngine
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_requests_event():
|
||||
engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
|
||||
engine = MockAsyncLLMEngine(worker_use_ray=False)
|
||||
engine.start_background_loop()
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.step_calls == 0
|
||||
@ -112,16 +113,11 @@ async def test_new_requests_event():
|
||||
assert engine.engine.add_request_calls == 3
|
||||
assert engine.engine.step_calls == old_step_calls + 1
|
||||
|
||||
# Allow deprecated engine_use_ray to not raise exception
|
||||
os.environ["VLLM_ALLOW_ENGINE_USE_RAY"] = "1"
|
||||
|
||||
engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
|
||||
engine = MockAsyncLLMEngine(worker_use_ray=True)
|
||||
assert engine.get_model_config() is not None
|
||||
assert engine.get_tokenizer() is not None
|
||||
assert engine.get_decoding_config() is not None
|
||||
|
||||
os.environ.pop("VLLM_ALLOW_ENGINE_USE_RAY")
|
||||
|
||||
|
||||
def start_engine():
|
||||
wait_for_gpu_memory_to_clear(
|
||||
@ -130,8 +126,17 @@ def start_engine():
|
||||
timeout_s=60,
|
||||
)
|
||||
|
||||
num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1"))
|
||||
print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}")
|
||||
|
||||
return AsyncLLMEngine.from_engine_args(
|
||||
AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
|
||||
AsyncEngineArgs(model="facebook/opt-125m",
|
||||
enforce_eager=True,
|
||||
num_scheduler_steps=num_scheduler_steps))
|
||||
|
||||
|
||||
def uid() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="module")
|
||||
@ -156,57 +161,177 @@ def should_do_global_cleanup_after_test(request) -> bool:
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
async def test_asyncio_run(async_engine):
|
||||
|
||||
scheduler_config = await async_engine.get_scheduler_config()
|
||||
num_scheduler_steps = scheduler_config.num_scheduler_steps
|
||||
|
||||
async def run(prompt: str):
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
min_tokens=32,
|
||||
)
|
||||
|
||||
output_count = 0
|
||||
final_output = None
|
||||
async for output in async_engine.generate(prompt,
|
||||
sampling_params,
|
||||
request_id=prompt):
|
||||
request_id=uid()):
|
||||
output_count += 1
|
||||
final_output = output
|
||||
return final_output
|
||||
return final_output, output_count
|
||||
|
||||
results = await asyncio.gather(
|
||||
run("test0"),
|
||||
run("test1"),
|
||||
run("test0"),
|
||||
)
|
||||
assert len(results) == 2
|
||||
first, second = results
|
||||
|
||||
# remove nondeterministic fields for comparison
|
||||
first[0].metrics = None
|
||||
second[0].metrics = None
|
||||
first[0].request_id = None
|
||||
second[0].request_id = None
|
||||
|
||||
assert str(first) == str(second)
|
||||
|
||||
output_count = results[0][1]
|
||||
if num_scheduler_steps == 1:
|
||||
assert output_count == 32
|
||||
else:
|
||||
assert 1 < output_count < 32
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
async def test_output_kinds(async_engine):
|
||||
"""Test that output_kind works as expected and that
|
||||
results are equivalent across different kinds."""
|
||||
|
||||
scheduler_config = await async_engine.get_scheduler_config()
|
||||
num_scheduler_steps = scheduler_config.num_scheduler_steps
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
min_tokens=32,
|
||||
)
|
||||
|
||||
async def run(prompt: str, kind: RequestOutputKind):
|
||||
params = copy(sampling_params)
|
||||
params.output_kind = kind
|
||||
|
||||
output_count = 0
|
||||
final_output = None
|
||||
async for output in async_engine.generate(prompt,
|
||||
params,
|
||||
request_id=uid()):
|
||||
output_count += 1
|
||||
final_output = output
|
||||
|
||||
assert final_output is not None
|
||||
return (final_output.prompt_token_ids,
|
||||
final_output.outputs[0].token_ids,
|
||||
final_output.outputs[0].text, output_count)
|
||||
|
||||
async def run_deltas(prompt: str):
|
||||
params = copy(sampling_params)
|
||||
params.output_kind = RequestOutputKind.DELTA
|
||||
|
||||
prompt_tokens = None
|
||||
output_tokens: List[int] = []
|
||||
output_text = ""
|
||||
output_count = 0
|
||||
async for output in async_engine.generate(prompt,
|
||||
params,
|
||||
request_id=uid()):
|
||||
token_ids = output.outputs[0].token_ids
|
||||
text = output.outputs[0].text
|
||||
|
||||
# Ensure we get prompt ids iff we haven't yet received output tokens
|
||||
if output_tokens:
|
||||
assert 1 <= len(token_ids) <= num_scheduler_steps
|
||||
assert text
|
||||
assert not output.prompt_token_ids
|
||||
else:
|
||||
assert output.prompt_token_ids
|
||||
prompt_tokens = output.prompt_token_ids
|
||||
|
||||
output_tokens.extend(token_ids)
|
||||
output_text += text
|
||||
|
||||
output_count += 1
|
||||
return prompt_tokens, output_tokens, output_text, output_count
|
||||
|
||||
results = await asyncio.gather(
|
||||
run("common input prompt", RequestOutputKind.CUMULATIVE),
|
||||
run("common input prompt", RequestOutputKind.FINAL_ONLY),
|
||||
run_deltas("common input prompt"))
|
||||
|
||||
# Make sure outputs are the same
|
||||
prompt_set = set(tuple(prompt_ids) for prompt_ids, _, _, _ in results)
|
||||
assert len(prompt_set) == 1
|
||||
|
||||
text_set = set(text for _, _, text, _ in results)
|
||||
assert len(text_set) == 1
|
||||
|
||||
tokens_set = set(tuple(ids) for _, ids, _, _ in results)
|
||||
assert len(tokens_set) == 1
|
||||
|
||||
cumulative, final, deltas = results
|
||||
|
||||
# output message counts
|
||||
assert cumulative[3] == deltas[3]
|
||||
|
||||
if num_scheduler_steps == 1:
|
||||
assert cumulative[3] == 32
|
||||
else:
|
||||
assert 1 < cumulative[3] < 32
|
||||
|
||||
assert final[3] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
async def test_cancellation(async_engine):
|
||||
scheduler_config = await async_engine.get_scheduler_config()
|
||||
num_scheduler_steps = scheduler_config.num_scheduler_steps
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
min_tokens=10,
|
||||
max_tokens=10,
|
||||
min_tokens=13,
|
||||
max_tokens=13,
|
||||
)
|
||||
|
||||
stop_at = 5 if num_scheduler_steps == 1 else 1
|
||||
|
||||
request_id = uid()
|
||||
|
||||
i = 0
|
||||
with pytest.raises(CancelledError):
|
||||
async for output in async_engine.generate("test2",
|
||||
sampling_params,
|
||||
request_id="test2"):
|
||||
request_id=request_id):
|
||||
assert not output.finished
|
||||
i += 1
|
||||
if i == 5:
|
||||
await async_engine.abort("test2")
|
||||
if i == stop_at:
|
||||
await async_engine.abort(request_id)
|
||||
|
||||
assert i == 5
|
||||
assert i == stop_at
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
async def test_delayed_generator(async_engine):
|
||||
scheduler_config = await async_engine.get_scheduler_config()
|
||||
|
||||
if scheduler_config.num_scheduler_steps != 1:
|
||||
pytest.skip("no need to test this one with multistep")
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
min_tokens=10,
|
||||
max_tokens=10,
|
||||
)
|
||||
|
||||
stream = async_engine.generate("test3",
|
||||
sampling_params,
|
||||
request_id="test3")
|
||||
stream = async_engine.generate("test3", sampling_params, request_id=uid())
|
||||
i = 0
|
||||
final_output: Optional[RealRequestOutput] = None
|
||||
async for output in stream:
|
||||
|
@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
|
||||
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
|
||||
load_chat_template)
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
@ -87,7 +88,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
|
||||
add_generation_prompt=add_generation_prompt)
|
||||
|
||||
# Call the function and get the result
|
||||
result = apply_chat_template(
|
||||
result = apply_hf_chat_template(
|
||||
tokenizer,
|
||||
conversation=mock_request.messages,
|
||||
chat_template=mock_request.chat_template or template_content,
|
||||
|
@ -1,5 +1,6 @@
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from ..utils import VLLM_PATH, RemoteOpenAIServer
|
||||
|
||||
@ -18,22 +19,18 @@ def server():
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
"--enforce-eager",
|
||||
"--engine-use-ray",
|
||||
"--chat-template",
|
||||
str(chatml_jinja_path),
|
||||
]
|
||||
|
||||
# Allow `--engine-use-ray`, otherwise the launch of the server throw
|
||||
# an error due to try to use a deprecated feature
|
||||
env_dict = {"VLLM_ALLOW_ENGINE_USE_RAY": "1"}
|
||||
with RemoteOpenAIServer(MODEL_NAME, args,
|
||||
env_dict=env_dict) as remote_server:
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server):
|
||||
return server.get_async_client()
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
@ -3,12 +3,16 @@
|
||||
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
|
||||
"""
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import weakref
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.utils import is_hip
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
from ..models.utils import check_outputs_equal
|
||||
|
||||
@ -64,3 +68,29 @@ def test_models(
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
def test_model_with_failure(vllm_runner) -> None:
|
||||
try:
|
||||
with patch("vllm.model_executor.models.opt.OPTForCausalLM.forward",
|
||||
side_effect=ValueError()):
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
vllm_runner("facebook/opt-125m",
|
||||
dtype="half",
|
||||
enforce_eager=False,
|
||||
gpu_memory_utilization=0.7)
|
||||
matches = re.search(r"input dumped to (.+).pkl",
|
||||
str(exc_info.value))
|
||||
assert matches is not None
|
||||
filename = f"{matches.group(1)}.pkl"
|
||||
|
||||
with open(filename, "rb") as filep:
|
||||
inputs = pickle.load(filep)
|
||||
|
||||
if any(key not in inputs for key in ("arg_1", "arg_2", "arg_3")):
|
||||
raise AssertionError("Missing keys in dumped inputs. Dumped keys: "
|
||||
f"{list(inputs.keys())}")
|
||||
assert isinstance(inputs["arg_1"],
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
finally:
|
||||
os.remove(filename)
|
||||
|
@ -6,6 +6,7 @@ prefill requests are chunked.
|
||||
|
||||
Run `pytest tests/models/test_chunked_prefill.py`.
|
||||
"""
|
||||
from contextlib import nullcontext
|
||||
|
||||
import pytest
|
||||
|
||||
@ -15,18 +16,6 @@ MODELS = [
|
||||
"facebook/opt-125m",
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
]
|
||||
E5M2_KV_MODELS = [
|
||||
"facebook/opt-125m",
|
||||
"meta-llama/Llama-2-7b-chat-hf",
|
||||
]
|
||||
E4M3_KV_MODELS = [
|
||||
"meta-llama/Llama-2-7b-chat-hf", "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
|
||||
"nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
|
||||
]
|
||||
KV_CACHE_QUANTIZATION_PATHS = {
|
||||
"meta-llama/Llama-2-7b-chat-hf":
|
||||
"./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json"
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@ -77,10 +66,10 @@ def test_models(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_cache_dtype,model",
|
||||
[("fp8_e5m2", m)
|
||||
for m in E5M2_KV_MODELS] + [("fp8_e4m3", m)
|
||||
for m in E4M3_KV_MODELS])
|
||||
@pytest.mark.parametrize(
|
||||
"kv_cache_dtype,model",
|
||||
[("fp8_e4m3",
|
||||
"nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme")])
|
||||
# Due to low-precision numerical divergence, we only test logprob of 4 tokens
|
||||
@pytest.mark.parametrize("max_tokens", [4])
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [4, 16])
|
||||
@ -88,6 +77,9 @@ def test_models(
|
||||
# NOTE: Increasing this in this suite will fail CI because we currently cannot
|
||||
# reset distributed env properly. Use a value > 1 just when you test.
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1])
|
||||
# Due to low-precision numerical divergence, this test is too sensitive to
|
||||
# the async postprocessor
|
||||
@pytest.mark.parametrize("disable_async_output_proc", [True])
|
||||
def test_models_with_fp8_kv_cache(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
@ -97,36 +89,25 @@ def test_models_with_fp8_kv_cache(
|
||||
chunked_prefill_token_size: int,
|
||||
enforce_eager: bool,
|
||||
tensor_parallel_size: int,
|
||||
disable_async_output_proc: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Only checks log probs match between chunked-prefill and
|
||||
non-chunked-prefill version of vLLM model runner.
|
||||
|
||||
This test is used when there is discrepancy in kernels
|
||||
/ numerics (e.g. when using lower-precision types like FP8).
|
||||
Check output logprobs match between no_chunked_prefill and chunked_prefill
|
||||
with fp8 kv cache. General fp8 kv-cache tests are covered in test_fp8.py,
|
||||
so here we only check chunked prefill.
|
||||
"""
|
||||
NUM_LOG_PROBS = 8
|
||||
|
||||
if model == "facebook/opt-125m":
|
||||
pytest.skip(
|
||||
"#7378: CUDA illegal memory access (undiagnosed) facebook/opt-125m"
|
||||
)
|
||||
|
||||
max_num_seqs = chunked_prefill_token_size
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
|
||||
extra_kwargs = {}
|
||||
if model in KV_CACHE_QUANTIZATION_PATHS:
|
||||
extra_kwargs["quantization_param_path"] = KV_CACHE_QUANTIZATION_PATHS[
|
||||
model]
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
enforce_eager=enforce_eager,
|
||||
max_num_seqs=max_num_seqs,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
**extra_kwargs,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
) as vllm_model:
|
||||
no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, NUM_LOG_PROBS)
|
||||
@ -139,7 +120,7 @@ def test_models_with_fp8_kv_cache(
|
||||
enforce_eager=enforce_eager,
|
||||
max_num_seqs=max_num_seqs,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
**extra_kwargs,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
) as vllm_model:
|
||||
chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, NUM_LOG_PROBS)
|
||||
@ -150,3 +131,68 @@ def test_models_with_fp8_kv_cache(
|
||||
name_0="no_chunked_prefill",
|
||||
name_1="chunked_prefill",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_tokens", [16])
|
||||
@pytest.mark.parametrize("enforce_eager", [False])
|
||||
@pytest.mark.parametrize("chunk_size", [30, 32])
|
||||
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
|
||||
# NOTE: Increasing this in this suite will fail CI because we currently cannot
|
||||
# reset distributed env properly. Use a value > 1 just when you test.
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1])
|
||||
def test_with_prefix_caching(
|
||||
vllm_runner,
|
||||
max_tokens: int,
|
||||
enforce_eager: bool,
|
||||
chunk_size: int,
|
||||
use_v2_block_manager: bool,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
"""
|
||||
Checks exact match decode with and without prefix caching
|
||||
with chunked prefill enabled.
|
||||
"""
|
||||
model = "meta-llama/Llama-2-7b-chat-hf"
|
||||
# The common prompt has 142 tokens with Llama-2 tokenizer.
|
||||
common_prompt = "You are a helpful AI assistant " * 20
|
||||
unique_prompts = [
|
||||
"Question", # Warmup
|
||||
"Question", # Fully cached
|
||||
"Another question", # Partial cached
|
||||
]
|
||||
full_prompts = [f"{common_prompt}\n{p}" for p in unique_prompts]
|
||||
|
||||
max_num_batched_tokens = max_num_seqs = chunk_size
|
||||
outputs = {} # type: ignore
|
||||
check_result = True
|
||||
for enable in (True, False):
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype="half",
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
enable_chunked_prefill=True,
|
||||
enable_prefix_caching=enable,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
use_v2_block_manager=use_v2_block_manager,
|
||||
enforce_eager=enforce_eager,
|
||||
max_num_seqs=max_num_seqs,
|
||||
) as vllm_model:
|
||||
# It should fail when prefix caching is enable and chunk
|
||||
# size is not a multiple of block size (16).
|
||||
should_fail = chunk_size % 16 != 0 and enable
|
||||
check_result &= not should_fail
|
||||
outputs[enable] = []
|
||||
# Send the request one-by-one to ensure the cache is populated.
|
||||
with pytest.raises(ValueError) if should_fail else nullcontext():
|
||||
for prompt in full_prompts:
|
||||
outputs[enable] += vllm_model.generate_greedy([prompt],
|
||||
max_tokens)
|
||||
|
||||
# Check results only if we did not expect a failure.
|
||||
if check_result:
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=outputs[False],
|
||||
outputs_1_lst=outputs[True],
|
||||
name_0="w/o prefix caching",
|
||||
name_1="with prefix caching",
|
||||
)
|
||||
|
@ -64,6 +64,7 @@ def test_chunked_prefill_recompute(
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_seqs=max_num_seqs,
|
||||
worker_use_ray=worker_use_ray,
|
||||
disable_log_stats=False,
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
|
||||
@ -209,7 +210,6 @@ def test_swap_infeasible(
|
||||
prefill_blocks = 2
|
||||
decode_blocks = max_tokens // BLOCK_SIZE
|
||||
example_prompts = example_prompts[:1]
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
|
Some files were not shown because too many files have changed in this diff.