mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Compare commits
113 Commits
Author | SHA1 | Date | |
---|---|---|---|
32e7db2536 | |||
008cf886c9 | |||
77d9e514a2 | |||
e02ce498be | |||
561d6f8077 | |||
d1dec64243 | |||
2ad2e5608e | |||
d3311562fb | |||
ccd7207191 | |||
855c262a6b | |||
2be8ec6e71 | |||
e16fa99a6a | |||
61f4a93d14 | |||
d4db9f53c8 | |||
2188a60c7e | |||
dc0b6066ab | |||
0af3abe3d3 | |||
f1575dc99f | |||
c02638efb3 | |||
652c83b697 | |||
6d646d08a2 | |||
95a178f861 | |||
bd852f2a8b | |||
ec266536b7 | |||
0fbc6696c2 | |||
6e36f4fa6c | |||
dd2a6a82e3 | |||
4ca65a9763 | |||
e2b2aa5a0f | |||
e6a26ed037 | |||
f8d60145b4 | |||
5b86b19954 | |||
5231f0898e | |||
8423aef4c8 | |||
4f5d8446ed | |||
d05f0a9db2 | |||
622f8abff8 | |||
1248e8506a | |||
2684efc467 | |||
058344f89a | |||
98cef6a227 | |||
f97be32d1d | |||
afd39a4511 | |||
2148441fd3 | |||
dc13e99348 | |||
34a0e96d46 | |||
80c7b089b1 | |||
428dd1445e | |||
4abed65c58 | |||
0c785d344d | |||
4664ceaad6 | |||
257afc37c5 | |||
86a677de42 | |||
d78789ac16 | |||
c334b1898b | |||
6b3421567d | |||
3f60f2244e | |||
f205c09854 | |||
ef99a78760 | |||
74d5543ec5 | |||
a7f65c2be9 | |||
4289cad37f | |||
af59df0a10 | |||
ce6bf3a2cf | |||
3cdfe1f38b | |||
fdd9daafa3 | |||
8c56e57def | |||
eeffde1ac0 | |||
e5697d161c | |||
b98cc28f91 | |||
ef9baee3c5 | |||
98c12cffe5 | |||
f52a43a8b9 | |||
e3580537a4 | |||
f508e03e7f | |||
51f86bf487 | |||
c166e7e43e | |||
bc6e42a9b1 | |||
fab5f53e2d | |||
9c71c97ae2 | |||
5340a2dccf | |||
345be0e244 | |||
fc911880cc | |||
ed6f002d33 | |||
b09c755be8 | |||
42e932c7d4 | |||
076169f603 | |||
9db642138b | |||
6fc4e6e07a | |||
9606c7197d | |||
64cc644425 | |||
39178c7fbc | |||
2eedede875 | |||
015e6cc252 | |||
760e9f71a8 | |||
05826c887b | |||
dd9857f5fa | |||
665304092d | |||
2deb029d11 | |||
029c71de11 | |||
0b769992ec | |||
1856aff4d6 | |||
70c094ade6 | |||
2059b8d9ca | |||
8aaf3d5347 | |||
80162c44b1 | |||
aab0fcdb63 | |||
ea9fa160e3 | |||
7d9ffa2ae1 | |||
d81abefd2e | |||
8da48e4d95 | |||
6885fde317 | |||
9db93de20c |
@ -1,36 +1,43 @@
|
||||
import os
|
||||
import sys
|
||||
import zipfile
|
||||
|
||||
MAX_SIZE_MB = 250
|
||||
# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 250 MB
|
||||
VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 250))
|
||||
|
||||
|
||||
def print_top_10_largest_files(zip_file):
|
||||
"""Print the top 10 largest files in the given zip file."""
|
||||
with zipfile.ZipFile(zip_file, 'r') as z:
|
||||
file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()]
|
||||
file_sizes.sort(key=lambda x: x[1], reverse=True)
|
||||
for f, size in file_sizes[:10]:
|
||||
print(f"{f}: {size/(1024*1024)} MBs uncompressed.")
|
||||
print(f"{f}: {size / (1024 * 1024):.2f} MBs uncompressed.")
|
||||
|
||||
|
||||
def check_wheel_size(directory):
|
||||
"""Check the size of .whl files in the given directory."""
|
||||
for root, _, files in os.walk(directory):
|
||||
for f in files:
|
||||
if f.endswith(".whl"):
|
||||
wheel_path = os.path.join(root, f)
|
||||
wheel_size = os.path.getsize(wheel_path)
|
||||
wheel_size_mb = wheel_size / (1024 * 1024)
|
||||
if wheel_size_mb > MAX_SIZE_MB:
|
||||
print(
|
||||
f"Wheel {wheel_path} is too large ({wheel_size_mb} MB) "
|
||||
f"compare to the allowed size ({MAX_SIZE_MB} MB).")
|
||||
for file_name in files:
|
||||
if file_name.endswith(".whl"):
|
||||
wheel_path = os.path.join(root, file_name)
|
||||
wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024)
|
||||
if wheel_size_mb > VLLM_MAX_SIZE_MB:
|
||||
print(f"Not allowed: Wheel {wheel_path} is larger "
|
||||
f"({wheel_size_mb:.2f} MB) than the limit "
|
||||
f"({VLLM_MAX_SIZE_MB} MB).")
|
||||
print_top_10_largest_files(wheel_path)
|
||||
return 1
|
||||
else:
|
||||
print(f"Wheel {wheel_path} is within the allowed size "
|
||||
f"({wheel_size_mb} MB).")
|
||||
f"({wheel_size_mb:.2f} MB).")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(check_wheel_size(sys.argv[1]))
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python check-wheel-size.py <directory>")
|
||||
sys.exit(1)
|
||||
|
||||
directory = sys.argv[1]
|
||||
sys.exit(check_wheel_size(directory))
|
@ -1,5 +1,4 @@
|
||||
Meta-Llama-3-8B-Instruct.yaml
|
||||
Meta-Llama-3-8B-Instruct-FP8.yaml
|
||||
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
|
||||
|
46
.buildkite/run-amd-test.sh
Normal file → Executable file
46
.buildkite/run-amd-test.sh
Normal file → Executable file
@ -1,5 +1,5 @@
|
||||
# This script runs test inside the corresponding ROCm docker container.
|
||||
set -ex
|
||||
set -o pipefail
|
||||
|
||||
# Print ROCm version
|
||||
echo "--- Confirming Clean Initial State"
|
||||
@ -70,15 +70,51 @@ HF_CACHE="$(realpath ~)/huggingface"
|
||||
mkdir -p ${HF_CACHE}
|
||||
HF_MOUNT="/root/.cache/huggingface"
|
||||
|
||||
docker run \
|
||||
commands=$@
|
||||
PARALLEL_JOB_COUNT=8
|
||||
# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs.
|
||||
if [[ $commands == *"--shard-id="* ]]; then
|
||||
for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do
|
||||
#replace shard arguments
|
||||
commands=${@//"--shard-id= "/"--shard-id=${GPU} "}
|
||||
commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "}
|
||||
docker run \
|
||||
--device /dev/kfd --device /dev/dri \
|
||||
--network host \
|
||||
--shm-size=16gb \
|
||||
--rm \
|
||||
-e HIP_VISIBLE_DEVICES=${GPU} \
|
||||
-e HF_TOKEN \
|
||||
-v ${HF_CACHE}:${HF_MOUNT} \
|
||||
-e HF_HOME=${HF_MOUNT} \
|
||||
--name ${container_name} \
|
||||
--name ${container_name}_${GPU} \
|
||||
${image_name} \
|
||||
/bin/bash -c "${@}"
|
||||
|
||||
/bin/bash -c "${commands}" \
|
||||
|& while read -r line; do echo ">>Shard $GPU: $line"; done &
|
||||
PIDS+=($!)
|
||||
done
|
||||
#wait for all processes to finish and collect exit codes
|
||||
for pid in ${PIDS[@]}; do
|
||||
wait ${pid}
|
||||
STATUS+=($?)
|
||||
done
|
||||
for st in ${STATUS[@]}; do
|
||||
if [[ ${st} -ne 0 ]]; then
|
||||
echo "One of the processes failed with $st"
|
||||
exit ${st}
|
||||
fi
|
||||
done
|
||||
else
|
||||
docker run \
|
||||
--device /dev/kfd --device /dev/dri \
|
||||
--network host \
|
||||
--shm-size=16gb \
|
||||
--rm \
|
||||
-e HIP_VISIBLE_DEVICES=0 \
|
||||
-e HF_TOKEN \
|
||||
-v ${HF_CACHE}:${HF_MOUNT} \
|
||||
-e HF_HOME=${HF_MOUNT} \
|
||||
--name ${container_name} \
|
||||
${image_name} \
|
||||
/bin/bash -c "${commands}"
|
||||
fi
|
||||
|
@ -23,7 +23,12 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
|
||||
# Run basic model test
|
||||
docker exec cpu-test bash -c "
|
||||
pip install pytest matplotlib einops transformers_stream_generator
|
||||
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
|
||||
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py \
|
||||
--ignore=tests/models/test_oot_registration.py \
|
||||
--ignore=tests/models/test_registry.py \
|
||||
--ignore=tests/models/test_fp8.py \
|
||||
--ignore=tests/models/test_jamba.py \
|
||||
--ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
|
||||
|
||||
# online inference
|
||||
docker exec cpu-test bash -c "
|
||||
|
@ -12,5 +12,4 @@ remove_docker_container
|
||||
# For HF_TOKEN.
|
||||
source /etc/environment
|
||||
# Run a simple end-to-end example.
|
||||
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu \
|
||||
python3 /workspace/vllm/examples/offline_inference_tpu.py
|
||||
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
|
||||
|
@ -87,8 +87,11 @@ steps:
|
||||
commands:
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
- pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api]
|
||||
- pytest -v -s entrypoints/llm
|
||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py
|
||||
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/openai
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
|
||||
|
||||
- label: Distributed Tests (4 GPUs) # 10min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
@ -172,6 +175,7 @@ steps:
|
||||
- vllm/
|
||||
commands:
|
||||
- pytest -v -s ./compile/test_full_graph.py
|
||||
- pytest -v -s ./compile/test_wrapper.py
|
||||
|
||||
|
||||
- label: Vision Language Models Test # 42min
|
||||
@ -215,9 +219,9 @@ steps:
|
||||
- pytest -v -s spec_decode
|
||||
|
||||
- label: LoRA Test %N # 30min each
|
||||
mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/lora
|
||||
- csrc/punica
|
||||
- tests/lora
|
||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
|
||||
parallelism: 4
|
||||
@ -232,12 +236,13 @@ steps:
|
||||
parallelism: 4
|
||||
|
||||
- label: Tensorizer Test # 11min
|
||||
mirror_hardwares: [amd]
|
||||
soft_fail: true
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/model_loader
|
||||
- tests/tensorizer_loader
|
||||
commands:
|
||||
- apt-get install -y curl libsodium23
|
||||
- apt-get update && apt-get install -y curl libsodium23
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s tensorizer_loader
|
||||
|
||||
@ -267,6 +272,15 @@ steps:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- bash ./run-tests.sh -c configs/models-small.txt -t 1
|
||||
|
||||
- label: OpenAI-Compatible Tool Use # 20 min
|
||||
fast_check: false
|
||||
mirror_hardwares: [ amd ]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/tool_use
|
||||
commands:
|
||||
- pytest -v -s tool_use
|
||||
|
||||
##### 1 GPU test #####
|
||||
##### multi gpus test #####
|
||||
|
||||
@ -335,7 +349,8 @@ steps:
|
||||
- vllm/engine
|
||||
- tests/multi_step
|
||||
commands:
|
||||
- pytest -v -s multi_step/test_correctness.py
|
||||
- pytest -v -s multi_step/test_correctness_async_llm.py
|
||||
- pytest -v -s multi_step/test_correctness_llm.py
|
||||
|
||||
- label: Pipeline Parallelism Test # 23min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
@ -355,7 +370,6 @@ steps:
|
||||
num_gpus: 4
|
||||
source_file_dependencies:
|
||||
- vllm/lora
|
||||
- csrc/punica
|
||||
- tests/lora/test_long_context
|
||||
commands:
|
||||
# FIXIT: find out which code initialize cuda before running the test
|
||||
|
23
.github/workflows/add_label_ready_comment.yml
vendored
23
.github/workflows/add_label_ready_comment.yml
vendored
@ -1,23 +0,0 @@
|
||||
name: Add Ready Label on Ready Comment
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
jobs:
|
||||
add-ready-label:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event.issue.pull_request && contains(github.event.comment.body, '/ready')
|
||||
steps:
|
||||
- name: Add label
|
||||
uses: actions/github-script@v5
|
||||
with:
|
||||
script: |
|
||||
github.rest.issues.addLabels({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
labels: ['ready']
|
||||
})
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
1
.github/workflows/mypy.yaml
vendored
1
.github/workflows/mypy.yaml
vendored
@ -35,7 +35,6 @@ jobs:
|
||||
mypy
|
||||
mypy tests --follow-imports skip
|
||||
mypy vllm/attention --follow-imports skip
|
||||
mypy vllm/core --follow-imports skip
|
||||
mypy vllm/distributed --follow-imports skip
|
||||
mypy vllm/engine --follow-imports skip
|
||||
mypy vllm/executor --follow-imports skip
|
||||
|
2
.github/workflows/reminder_comment.yml
vendored
2
.github/workflows/reminder_comment.yml
vendored
@ -15,7 +15,7 @@ jobs:
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your `fast-check` build on Buildkite UI. \n\nOnce the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).\n\n To run full CI, you can do one of these:\n- Comment `/ready` on the PR\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
|
||||
body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org. \n\nOnce the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n To run CI, PR reviewers can do one of these:\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
|
||||
})
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
@ -1,23 +0,0 @@
|
||||
name: Remove ready Label on notready Comment
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
jobs:
|
||||
add-ready-label:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event.issue.pull_request && contains(github.event.comment.body, '/notready')
|
||||
steps:
|
||||
- name: Remove ready label
|
||||
uses: actions/github-script@v5
|
||||
with:
|
||||
script: |
|
||||
github.rest.issues.removeLabel({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
name: 'ready'
|
||||
})
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
@ -203,6 +203,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
FetchContent_MakeAvailable(cutlass)
|
||||
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
|
||||
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
|
||||
"csrc/quantization/aqlm/gemm_kernels.cu"
|
||||
"csrc/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
|
||||
@ -296,6 +298,11 @@ set(VLLM_MOE_EXT_SRC
|
||||
"csrc/moe/torch_bindings.cpp"
|
||||
"csrc/moe/topk_softmax_kernels.cu")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC
|
||||
"csrc/moe/marlin_moe_ops.cu")
|
||||
endif()
|
||||
|
||||
define_gpu_extension_target(
|
||||
_moe_C
|
||||
DESTINATION vllm
|
||||
|
38
Dockerfile
38
Dockerfile
@ -42,9 +42,6 @@ COPY requirements-cuda.txt requirements-cuda.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install -r requirements-cuda.txt
|
||||
|
||||
COPY requirements-mamba.txt requirements-mamba.txt
|
||||
RUN python3 -m pip install packaging
|
||||
RUN python3 -m pip install -r requirements-mamba.txt
|
||||
|
||||
# cuda arch list used by torch
|
||||
# can be useful for both `dev` and `test`
|
||||
@ -111,10 +108,17 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
|
||||
fi
|
||||
|
||||
# check the size of the wheel, we cannot upload wheels larger than 100MB
|
||||
# Check the size of the wheel if RUN_WHEEL_CHECK is true
|
||||
COPY .buildkite/check-wheel-size.py check-wheel-size.py
|
||||
RUN python3 check-wheel-size.py dist
|
||||
|
||||
# Default max size of the wheel is 250MB
|
||||
ARG VLLM_MAX_SIZE_MB=250
|
||||
ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB
|
||||
ARG RUN_WHEEL_CHECK=true
|
||||
RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
|
||||
python3 check-wheel-size.py dist; \
|
||||
else \
|
||||
echo "Skipping wheel size check."; \
|
||||
fi
|
||||
#################### EXTENSION Build IMAGE ####################
|
||||
|
||||
#################### DEV IMAGE ####################
|
||||
@ -127,22 +131,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install -r requirements-dev.txt
|
||||
|
||||
#################### DEV IMAGE ####################
|
||||
#################### MAMBA Build IMAGE ####################
|
||||
FROM dev as mamba-builder
|
||||
# max jobs used for build
|
||||
ARG max_jobs=2
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
|
||||
WORKDIR /usr/src/mamba
|
||||
|
||||
COPY requirements-mamba.txt requirements-mamba.txt
|
||||
|
||||
# Download the wheel or build it if a pre-compiled release doesn't exist
|
||||
RUN pip --verbose wheel -r requirements-mamba.txt \
|
||||
--no-build-isolation --no-deps --no-cache-dir
|
||||
|
||||
#################### MAMBA Build IMAGE ####################
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
# image with vLLM installed
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base
|
||||
@ -179,13 +167,9 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install dist/*.whl --verbose
|
||||
|
||||
RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
. /etc/environment && \
|
||||
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
|
||||
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
ARG NIGHTLY_DATE="20240808"
|
||||
ARG NIGHTLY_DATE="20240828"
|
||||
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
|
||||
|
||||
FROM $BASE_IMAGE
|
||||
|
@ -56,20 +56,27 @@ class BenchmarkMetrics:
|
||||
total_input: int
|
||||
total_output: int
|
||||
request_throughput: float
|
||||
input_throughput: float
|
||||
output_throughput: float
|
||||
total_token_throughput: float
|
||||
mean_ttft_ms: float
|
||||
median_ttft_ms: float
|
||||
std_ttft_ms: float
|
||||
p99_ttft_ms: float
|
||||
percentiles_ttft_ms: List[Tuple[float, float]]
|
||||
mean_tpot_ms: float
|
||||
median_tpot_ms: float
|
||||
std_tpot_ms: float
|
||||
p99_tpot_ms: float
|
||||
percentiles_tpot_ms: List[Tuple[float, float]]
|
||||
mean_itl_ms: float
|
||||
median_itl_ms: float
|
||||
std_itl_ms: float
|
||||
p99_itl_ms: float
|
||||
percentiles_itl_ms: List[Tuple[float, float]]
|
||||
# E2EL stands for end-to-end latency per request.
|
||||
# It is the time taken on the client side from sending
|
||||
# a request to receiving a complete response.
|
||||
mean_e2el_ms: float
|
||||
median_e2el_ms: float
|
||||
std_e2el_ms: float
|
||||
percentiles_e2el_ms: List[Tuple[float, float]]
|
||||
|
||||
|
||||
def sample_sharegpt_requests(
|
||||
@ -235,6 +242,8 @@ def calculate_metrics(
|
||||
outputs: List[RequestFuncOutput],
|
||||
dur_s: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
selected_percentile_metrics: List[str],
|
||||
selected_percentiles: List[float],
|
||||
) -> Tuple[BenchmarkMetrics, List[int]]:
|
||||
actual_output_lens: List[int] = []
|
||||
total_input = 0
|
||||
@ -242,6 +251,7 @@ def calculate_metrics(
|
||||
itls: List[float] = []
|
||||
tpots: List[float] = []
|
||||
ttfts: List[float] = []
|
||||
e2els: List[float] = []
|
||||
for i in range(len(outputs)):
|
||||
if outputs[i].success:
|
||||
# We use the tokenizer to count the number of output tokens for all
|
||||
@ -258,6 +268,7 @@ def calculate_metrics(
|
||||
(outputs[i].latency - outputs[i].ttft) / (output_len - 1))
|
||||
itls += outputs[i].itl
|
||||
ttfts.append(outputs[i].ttft)
|
||||
e2els.append(outputs[i].latency)
|
||||
completed += 1
|
||||
else:
|
||||
actual_output_lens.append(0)
|
||||
@ -272,21 +283,29 @@ def calculate_metrics(
|
||||
total_input=total_input,
|
||||
total_output=sum(actual_output_lens),
|
||||
request_throughput=completed / dur_s,
|
||||
input_throughput=total_input / dur_s,
|
||||
output_throughput=sum(actual_output_lens) / dur_s,
|
||||
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
|
||||
mean_ttft_ms=np.mean(ttfts or 0) *
|
||||
1000, # ttfts is empty if streaming is not supported by backend
|
||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
||||
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
|
||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||
std_tpot_ms=np.std(tpots or 0) * 1000,
|
||||
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
|
||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
mean_itl_ms=np.mean(itls or 0) * 1000,
|
||||
median_itl_ms=np.median(itls or 0) * 1000,
|
||||
std_itl_ms=np.std(itls or 0) * 1000,
|
||||
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
|
||||
median_itl_ms=np.median(itls or 0) * 1000,
|
||||
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
mean_e2el_ms=np.median(e2els or 0) * 1000,
|
||||
std_e2el_ms=np.std(e2els or 0) * 1000,
|
||||
median_e2el_ms=np.mean(e2els or 0) * 1000,
|
||||
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
)
|
||||
|
||||
return metrics, actual_output_lens
|
||||
@ -304,6 +323,8 @@ async def benchmark(
|
||||
request_rate: float,
|
||||
disable_tqdm: bool,
|
||||
profile: bool,
|
||||
selected_percentile_metrics: List[str],
|
||||
selected_percentiles: List[str],
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
@ -392,6 +413,8 @@ async def benchmark(
|
||||
outputs=outputs,
|
||||
dur_s=benchmark_duration,
|
||||
tokenizer=tokenizer,
|
||||
selected_percentile_metrics=selected_percentile_metrics,
|
||||
selected_percentiles=selected_percentiles,
|
||||
)
|
||||
|
||||
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
|
||||
@ -403,27 +426,10 @@ async def benchmark(
|
||||
metrics.total_output))
|
||||
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
||||
metrics.request_throughput))
|
||||
print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):",
|
||||
metrics.input_throughput))
|
||||
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
||||
metrics.output_throughput))
|
||||
print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-'))
|
||||
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TTFT (ms):",
|
||||
metrics.median_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
||||
print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)',
|
||||
n=50,
|
||||
c='-'))
|
||||
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
|
||||
metrics.median_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
||||
print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-'))
|
||||
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
|
||||
print("=" * 50)
|
||||
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
|
||||
metrics.total_token_throughput))
|
||||
|
||||
result = {
|
||||
"duration": benchmark_duration,
|
||||
@ -431,20 +437,8 @@ async def benchmark(
|
||||
"total_input_tokens": metrics.total_input,
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"request_throughput": metrics.request_throughput,
|
||||
"input_throughput": metrics.input_throughput,
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"mean_ttft_ms": metrics.mean_ttft_ms,
|
||||
"median_ttft_ms": metrics.median_ttft_ms,
|
||||
"std_ttft_ms": metrics.std_ttft_ms,
|
||||
"p99_ttft_ms": metrics.p99_ttft_ms,
|
||||
"mean_tpot_ms": metrics.mean_tpot_ms,
|
||||
"median_tpot_ms": metrics.median_tpot_ms,
|
||||
"std_tpot_ms": metrics.std_tpot_ms,
|
||||
"p99_tpot_ms": metrics.p99_tpot_ms,
|
||||
"mean_itl_ms": metrics.mean_itl_ms,
|
||||
"median_itl_ms": metrics.median_itl_ms,
|
||||
"std_itl_ms": metrics.std_itl_ms,
|
||||
"p99_itl_ms": metrics.p99_itl_ms,
|
||||
"total_token_throughput": metrics.total_token_throughput,
|
||||
"input_lens": [output.prompt_len for output in outputs],
|
||||
"output_lens": actual_output_lens,
|
||||
"ttfts": [output.ttft for output in outputs],
|
||||
@ -452,6 +446,47 @@ async def benchmark(
|
||||
"generated_texts": [output.generated_text for output in outputs],
|
||||
"errors": [output.error for output in outputs],
|
||||
}
|
||||
|
||||
def process_one_metric(
|
||||
# E.g., "ttft"
|
||||
metric_attribute_name: str,
|
||||
# E.g., "TTFT"
|
||||
metric_name: str,
|
||||
# E.g., "Time to First Token"
|
||||
metric_header: str,
|
||||
):
|
||||
# This function print and add statistics of the specified
|
||||
# metric.
|
||||
if metric_attribute_name not in selected_percentile_metrics:
|
||||
return
|
||||
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
|
||||
print("{:<40} {:<10.2f}".format(
|
||||
f"Mean {metric_name} (ms):",
|
||||
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
|
||||
print("{:<40} {:<10.2f}".format(
|
||||
f"Median {metric_name} (ms):",
|
||||
getattr(metrics, f"median_{metric_attribute_name}_ms")))
|
||||
result[f"mean_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"mean_{metric_attribute_name}_ms")
|
||||
result[f"median_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"median_{metric_attribute_name}_ms")
|
||||
result[f"std_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"std_{metric_attribute_name}_ms")
|
||||
for p, value in getattr(metrics,
|
||||
f"percentiles_{metric_attribute_name}_ms"):
|
||||
p_word = str(int(p)) if int(p) == p else str(p)
|
||||
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
|
||||
value))
|
||||
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
|
||||
|
||||
process_one_metric("ttft", "TTFT", "Time to First Token")
|
||||
process_one_metric("tpot", "TPOT",
|
||||
"Time per Output Token (excl. 1st token)")
|
||||
process_one_metric("itl", "ITL", "Inter-token Latency")
|
||||
process_one_metric("e2el", "E2EL", "End-to-end Latency")
|
||||
|
||||
print("=" * 50)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@ -550,6 +585,10 @@ def main(args: argparse.Namespace):
|
||||
request_rate=args.request_rate,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
profile=args.profile,
|
||||
selected_percentile_metrics=args.percentile_metrics.split(","),
|
||||
selected_percentiles=[
|
||||
float(p) for p in args.metric_percentiles.split(",")
|
||||
],
|
||||
))
|
||||
|
||||
# Save config and results to json
|
||||
@ -765,6 +804,23 @@ if __name__ == "__main__":
|
||||
"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
|
||||
" format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--percentile-metrics",
|
||||
type=str,
|
||||
default="ttft,tpot,itl",
|
||||
help="Comma-seperated list of selected metrics to report percentils. "
|
||||
"This argument specifies the metrics to report percentiles. "
|
||||
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
|
||||
"Default value is \"ttft,tpot,itl\".")
|
||||
parser.add_argument(
|
||||
"--metric-percentiles",
|
||||
type=str,
|
||||
default="99",
|
||||
help="Comma-seperated list of percentiles for selected metrics. "
|
||||
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
|
||||
"Default value is \"99\". "
|
||||
"Use \"--percentile-metrics\" to select metrics.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -6,13 +6,16 @@ import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from tqdm import tqdm
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
PreTrainedTokenizerBase)
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args)
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
||||
|
||||
|
||||
def sample_requests(
|
||||
@ -82,8 +85,11 @@ def run_vllm(
|
||||
max_num_batched_tokens: int,
|
||||
distributed_executor_backend: Optional[str],
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
num_scheduler_steps: int = 1,
|
||||
use_v2_block_manager: bool = False,
|
||||
download_dir: Optional[str] = None,
|
||||
load_format: str = EngineArgs.load_format,
|
||||
disable_async_output_proc: bool = False,
|
||||
) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(
|
||||
@ -106,6 +112,9 @@ def run_vllm(
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
load_format=load_format,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
use_v2_block_manager=use_v2_block_manager,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
)
|
||||
|
||||
# Add the requests to the engine.
|
||||
@ -129,6 +138,93 @@ def run_vllm(
|
||||
return end - start
|
||||
|
||||
|
||||
async def run_vllm_async(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
model: str,
|
||||
tokenizer: str,
|
||||
quantization: Optional[str],
|
||||
tensor_parallel_size: int,
|
||||
seed: int,
|
||||
n: int,
|
||||
use_beam_search: bool,
|
||||
trust_remote_code: bool,
|
||||
dtype: str,
|
||||
max_model_len: Optional[int],
|
||||
enforce_eager: bool,
|
||||
kv_cache_dtype: str,
|
||||
quantization_param_path: Optional[str],
|
||||
device: str,
|
||||
enable_prefix_caching: bool,
|
||||
enable_chunked_prefill: bool,
|
||||
max_num_batched_tokens: int,
|
||||
distributed_executor_backend: Optional[str],
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
num_scheduler_steps: int = 1,
|
||||
use_v2_block_manager: bool = False,
|
||||
download_dir: Optional[str] = None,
|
||||
load_format: str = EngineArgs.load_format,
|
||||
disable_async_output_proc: bool = False,
|
||||
disable_frontend_multiprocessing: bool = False,
|
||||
) -> float:
|
||||
from vllm import SamplingParams
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
quantization=quantization,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
seed=seed,
|
||||
trust_remote_code=trust_remote_code,
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
enforce_eager=enforce_eager,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
quantization_param_path=quantization_param_path,
|
||||
device=device,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
download_dir=download_dir,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
load_format=load_format,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
use_v2_block_manager=use_v2_block_manager,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
worker_use_ray=False,
|
||||
engine_use_ray=False,
|
||||
disable_log_requests=True,
|
||||
)
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args, disable_frontend_multiprocessing) as llm:
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: List[str] = []
|
||||
sampling_params: List[SamplingParams] = []
|
||||
for prompt, _, output_len in requests:
|
||||
prompts.append(prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=0.0 if use_beam_search else 1.0,
|
||||
top_p=1.0,
|
||||
use_beam_search=use_beam_search,
|
||||
ignore_eos=True,
|
||||
max_tokens=output_len,
|
||||
))
|
||||
|
||||
generators = []
|
||||
start = time.perf_counter()
|
||||
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
|
||||
generator = llm.generate(prompt, sp, request_id=f"test{i}")
|
||||
generators.append(generator)
|
||||
all_gens = merge_async_iterators(*generators)
|
||||
async for i, res in all_gens:
|
||||
pass
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
def run_hf(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
model: str,
|
||||
@ -224,7 +320,7 @@ def main(args: argparse.Namespace):
|
||||
args.output_len)
|
||||
|
||||
if args.backend == "vllm":
|
||||
elapsed_time = run_vllm(
|
||||
run_args = [
|
||||
requests, args.model, args.tokenizer, args.quantization,
|
||||
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
|
||||
args.trust_remote_code, args.dtype, args.max_model_len,
|
||||
@ -232,7 +328,16 @@ def main(args: argparse.Namespace):
|
||||
args.quantization_param_path, args.device,
|
||||
args.enable_prefix_caching, args.enable_chunked_prefill,
|
||||
args.max_num_batched_tokens, args.distributed_executor_backend,
|
||||
args.gpu_memory_utilization, args.download_dir, args.load_format)
|
||||
args.gpu_memory_utilization, args.num_scheduler_steps,
|
||||
args.use_v2_block_manager, args.download_dir, args.load_format,
|
||||
args.disable_async_output_proc
|
||||
]
|
||||
|
||||
if args.async_engine:
|
||||
run_args.append(args.disable_frontend_multiprocessing)
|
||||
elapsed_time = uvloop.run(run_vllm_async(*run_args))
|
||||
else:
|
||||
elapsed_time = run_vllm(*run_args)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
@ -353,10 +458,18 @@ if __name__ == "__main__":
|
||||
choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
|
||||
help='device type for vLLM execution, supporting CUDA, OpenVINO and '
|
||||
'CPU.')
|
||||
parser.add_argument(
|
||||
"--num-scheduler-steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Maximum number of forward steps per scheduler call.")
|
||||
parser.add_argument("--use-v2-block-manager",
|
||||
action='store_true',
|
||||
help="Enable block manager v2.")
|
||||
parser.add_argument(
|
||||
"--enable-prefix-caching",
|
||||
action='store_true',
|
||||
help="enable automatic prefix caching for vLLM backend.")
|
||||
help="Enable automatic prefix caching for vLLM backend.")
|
||||
parser.add_argument("--enable-chunked-prefill",
|
||||
action='store_true',
|
||||
help="enable chunked prefill for vLLM backend.")
|
||||
@ -405,6 +518,19 @@ if __name__ == "__main__":
|
||||
'section for more information.\n'
|
||||
'* "bitsandbytes" will load the weights using bitsandbytes '
|
||||
'quantization.\n')
|
||||
parser.add_argument(
|
||||
"--disable-async-output-proc",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable async output processor for vLLM backend.")
|
||||
parser.add_argument("--async-engine",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Use vLLM async engine rather than LLM class.")
|
||||
parser.add_argument("--disable-frontend-multiprocessing",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable decoupled async engine frontend.")
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
|
@ -6,7 +6,7 @@ TOKENS=$2
|
||||
|
||||
docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \
|
||||
-v $PWD/data:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:1.4.0 \
|
||||
ghcr.io/huggingface/text-generation-inference:2.2.0 \
|
||||
--model-id $MODEL \
|
||||
--sharded false \
|
||||
--max-input-length 1024 \
|
||||
|
@ -387,7 +387,8 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
|
||||
// This needs to be implemented and throw a TypeError in order for
|
||||
// PyTorch's opcheck to work on ops that use ScalarTypes.
|
||||
int64_t len() const {
|
||||
throw c10::TypeError("__len__ not implemented");
|
||||
throw c10::TypeError({__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
|
||||
"__len__ not implemented");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
700
csrc/mamba/causal_conv1d/causal_conv1d.cu
Normal file
700
csrc/mamba/causal_conv1d/causal_conv1d.cu
Normal file
@ -0,0 +1,700 @@
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
|
||||
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "causal_conv1d.h"
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
|
||||
#include "static_switch.h"
|
||||
|
||||
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
||||
if (ITYPE == at::ScalarType::Half) { \
|
||||
using input_t = at::Half; \
|
||||
using weight_t = at::Half; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
||||
using input_t = at::BFloat16; \
|
||||
using weight_t = at::BFloat16; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::Float) { \
|
||||
using input_t = float; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
||||
}
|
||||
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template <typename input_t, typename weight_t>
|
||||
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
void set_conv_params_fwd(ConvParamsBase ¶ms,
|
||||
// sizes
|
||||
const size_t batch,
|
||||
const size_t dim,
|
||||
const size_t seqlen,
|
||||
const size_t width,
|
||||
// device pointers
|
||||
const at::Tensor x,
|
||||
const at::Tensor weight,
|
||||
const at::Tensor out,
|
||||
void* bias_ptr,
|
||||
bool silu_activation) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
|
||||
params.batch = batch;
|
||||
params.dim = dim;
|
||||
params.seqlen = seqlen;
|
||||
params.width = width;
|
||||
|
||||
params.silu_activation = silu_activation;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.x_ptr = x.data_ptr();
|
||||
params.weight_ptr = weight.data_ptr();
|
||||
params.bias_ptr = bias_ptr;
|
||||
params.out_ptr = out.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.x_batch_stride = x.stride(0);
|
||||
params.x_c_stride = x.stride(1);
|
||||
params.x_l_stride = x.stride(-1);
|
||||
params.weight_c_stride = weight.stride(0);
|
||||
params.weight_width_stride = weight.stride(1);
|
||||
params.out_batch_stride = out.stride(0);
|
||||
params.out_c_stride = out.stride(1);
|
||||
params.out_l_stride = out.stride(-1);
|
||||
}
|
||||
|
||||
|
||||
at::Tensor
|
||||
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
const c10::optional<at::Tensor> &bias_,
|
||||
const c10::optional<at::Tensor> &seq_idx_,
|
||||
const c10::optional<at::Tensor> &initial_states_,
|
||||
const c10::optional<at::Tensor> &final_states_out_,
|
||||
bool silu_activation) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int seqlen = sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
|
||||
const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
|
||||
|
||||
if (is_channel_last) {
|
||||
TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
|
||||
TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
|
||||
}
|
||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
if (seq_idx_.has_value()) {
|
||||
TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout");
|
||||
auto seq_idx = seq_idx_.value();
|
||||
TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
|
||||
TORCH_CHECK(seq_idx.is_cuda());
|
||||
TORCH_CHECK(seq_idx.is_contiguous());
|
||||
CHECK_SHAPE(seq_idx, batch_size, seqlen);
|
||||
}
|
||||
|
||||
at::Tensor out = torch::empty_like(x);
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
||||
silu_activation);
|
||||
|
||||
if (seq_idx_.has_value()) {
|
||||
params.seq_idx_ptr = seq_idx_.value().data_ptr();
|
||||
} else {
|
||||
params.seq_idx_ptr = nullptr;
|
||||
}
|
||||
|
||||
if (initial_states_.has_value()) {
|
||||
TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
|
||||
auto initial_states = initial_states_.value();
|
||||
TORCH_CHECK(initial_states.scalar_type() == input_type);
|
||||
TORCH_CHECK(initial_states.is_cuda());
|
||||
CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
|
||||
TORCH_CHECK(initial_states.stride(1) == 1);
|
||||
params.initial_states_ptr = initial_states.data_ptr();
|
||||
params.initial_states_batch_stride = initial_states.stride(0);
|
||||
params.initial_states_c_stride = initial_states.stride(1);
|
||||
params.initial_states_l_stride = initial_states.stride(2);
|
||||
} else {
|
||||
params.initial_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
if (final_states_out_.has_value()) {
|
||||
TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout");
|
||||
auto final_states = final_states_out_.value();
|
||||
TORCH_CHECK(final_states.scalar_type() == input_type);
|
||||
TORCH_CHECK(final_states.is_cuda());
|
||||
CHECK_SHAPE(final_states, batch_size, dim, width - 1);
|
||||
TORCH_CHECK(final_states.stride(1) == 1);
|
||||
params.final_states_ptr = final_states.data_ptr();
|
||||
params.final_states_batch_stride = final_states.stride(0);
|
||||
params.final_states_c_stride = final_states.stride(1);
|
||||
params.final_states_l_stride = final_states.stride(2);
|
||||
} else {
|
||||
params.final_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
|
||||
if (!is_channel_last) {
|
||||
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
} else {
|
||||
causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
}
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
at::Tensor
|
||||
causal_conv1d_update(const at::Tensor &x,
|
||||
const at::Tensor &conv_state,
|
||||
const at::Tensor &weight,
|
||||
const c10::optional<at::Tensor> &bias_,
|
||||
bool silu_activation) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");
|
||||
TORCH_CHECK(conv_state.scalar_type() == input_type);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(conv_state.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int width = weight.size(-1);
|
||||
|
||||
CHECK_SHAPE(x, batch_size, dim);
|
||||
CHECK_SHAPE(conv_state, batch_size, dim, width);
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
at::Tensor out = torch::empty_like(x);
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out,
|
||||
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
||||
silu_activation);
|
||||
params.conv_state_ptr = conv_state.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.conv_state_batch_stride = conv_state.stride(0);
|
||||
params.conv_state_c_stride = conv_state.stride(1);
|
||||
params.conv_state_l_stride = conv_state.stride(2);
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
|
||||
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_fwd_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
||||
static_assert(kWidth <= kNElts);
|
||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
||||
static constexpr int kSmemIOSize = kIsVecLoad
|
||||
? 0
|
||||
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
||||
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
|
||||
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNElts = Ktraits::kNElts;
|
||||
static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
||||
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
||||
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
||||
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y;
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
|
||||
if (tidx == 0) {
|
||||
input_t zeros[kNElts] = {0};
|
||||
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
|
||||
}
|
||||
|
||||
float weight_vals[kWidth];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNElts;
|
||||
const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
|
||||
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
||||
input_t x_vals_load[2 * kNElts] = {0};
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
__syncthreads();
|
||||
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
x += kChunkSize;
|
||||
__syncthreads();
|
||||
// Thread kNThreads - 1 don't write yet, so that thread 0 can read
|
||||
// the last elements of the previous chunk.
|
||||
if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
__syncthreads();
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
||||
__syncthreads();
|
||||
// Now thread kNThreads - 1 can write the last elements of the current chunk.
|
||||
if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
|
||||
float x_vals[2 * kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
||||
|
||||
float out_vals[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = bias_val;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
||||
}
|
||||
}
|
||||
|
||||
if (params.silu_activation) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
input_t out_vals_store[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
out += kChunkSize;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
||||
BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
|
||||
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize;
|
||||
dim3 grid(params.batch, params.dim);
|
||||
|
||||
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
|
||||
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
#ifndef USE_ROCM
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
#else
|
||||
// There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
||||
#endif
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_channellast_fwd_kernel_traits {
|
||||
// The cache line is 128 bytes, and we try to read 16 bytes per thread.
|
||||
// So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
|
||||
// That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
|
||||
// threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static_assert(kNThreads % 32 == 0);
|
||||
static constexpr int kNWarps = kNThreads / 32;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kChunkSizeL = kChunkSizeL_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
||||
static constexpr int kNEltsPerRow = 128 / kNBytes;
|
||||
static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
|
||||
static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
|
||||
static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
|
||||
static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
|
||||
static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
|
||||
static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
|
||||
static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
|
||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
// using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
// using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
// static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
||||
// sizeof(typename BlockStoreT::TempStorage)});
|
||||
// static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
|
||||
};
|
||||
|
||||
template<typename Ktraits, bool kHasSeqIdx>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNElts = Ktraits::kNElts;
|
||||
constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
|
||||
constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
|
||||
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
||||
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
// Shared memory.
|
||||
__shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
|
||||
|
||||
const int batch_id = blockIdx.x;
|
||||
const int chunk_l_id = blockIdx.y;
|
||||
const int chunk_c_id = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
const int l_idx = tid / kNThreadsPerC;
|
||||
const int c_idx = tid % kNThreadsPerC;
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
|
||||
+ chunk_c_id * kChunkSizeC * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
|
||||
+ batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
|
||||
input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
|
||||
: reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
// The last L-chunk will also have enough info to write to final states, since it also contain a few x values
|
||||
// from the previous L-chunk.
|
||||
input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
|
||||
: reinterpret_cast<input_t *>(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
||||
input_t x_vals_load[kNElts] = {0};
|
||||
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
|
||||
}
|
||||
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
||||
}
|
||||
// Load the elements from the previous chunk that are needed for convolution.
|
||||
if (l_idx < kWidth - 1) {
|
||||
input_t x_vals_load[kNElts] = {0};
|
||||
if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
|
||||
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
|
||||
} else if (initial_states != nullptr
|
||||
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
|
||||
}
|
||||
reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (final_states != nullptr
|
||||
&& l_idx < kWidth - 1
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
// x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1)
|
||||
// So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx]
|
||||
*reinterpret_cast<vec_t *>(final_states) = reinterpret_cast<vec_t *>(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
|
||||
}
|
||||
|
||||
constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
|
||||
static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
|
||||
constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
|
||||
static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
|
||||
// kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
|
||||
static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
|
||||
static_assert((kLPerThread & (kLPerThread - 1)) == 0);
|
||||
static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
|
||||
static_assert(kNThreadsPerRow <= 32);
|
||||
|
||||
const int row_idx = tid / kNThreadsPerRow;
|
||||
const int col_idx = tid % kNThreadsPerRow;
|
||||
|
||||
float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
|
||||
float weight_vals[kWidth] = {0};
|
||||
if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
|
||||
}
|
||||
}
|
||||
float x_vals[kWidth - 1 + kLPerThread];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
||||
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
|
||||
}
|
||||
int seq_idx_thread[kWidth - 1 + kLPerThread];
|
||||
if constexpr (kHasSeqIdx) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
||||
seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
|
||||
}
|
||||
}
|
||||
|
||||
float out_vals[kLPerThread];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kLPerThread; ++i) {
|
||||
out_vals[i] = bias_val;
|
||||
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
if constexpr (!kHasSeqIdx) {
|
||||
out_vals[i] += weight_vals[w] * x_vals[i + w];
|
||||
} else {
|
||||
out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
|
||||
}
|
||||
}
|
||||
if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
||||
input_t out_vals_store[kNElts];
|
||||
reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
|
||||
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
*reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
|
||||
using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
|
||||
// constexpr int kSmemSize = Ktraits::kSmemSize;
|
||||
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
||||
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
||||
const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
|
||||
const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
|
||||
dim3 grid(params.batch, n_chunks_L, n_chunks_C);
|
||||
dim3 block(Ktraits::kNThreads);
|
||||
auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>;
|
||||
// if (kSmemSize >= 48 * 1024) {
|
||||
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
// }
|
||||
// kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
///////
|
||||
|
||||
|
||||
|
||||
|
||||
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_update_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_update_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y * kNThreads + tidx;
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
|
||||
+ channel_id * params.conv_state_c_stride;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
float weight_vals[kWidth] = {0};
|
||||
if (channel_id < params.dim) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
}
|
||||
|
||||
float x_vals[kWidth] = {0};
|
||||
if (channel_id < params.dim) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
|
||||
x_vals[kWidth - 1] = float(x[0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
|
||||
}
|
||||
|
||||
float out_val = bias_val;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
|
||||
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
|
||||
if (channel_id < params.dim) { out[0] = input_t(out_val); }
|
||||
}
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
|
||||
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
|
||||
auto kernel = &causal_conv1d_update_kernel<Ktraits>;
|
||||
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
144
csrc/mamba/causal_conv1d/causal_conv1d.h
Normal file
144
csrc/mamba/causal_conv1d/causal_conv1d.h
Normal file
@ -0,0 +1,144 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ConvParamsBase {
|
||||
using index_t = uint32_t;
|
||||
|
||||
int batch, dim, seqlen, width;
|
||||
bool silu_activation;
|
||||
|
||||
index_t x_batch_stride;
|
||||
index_t x_c_stride;
|
||||
index_t x_l_stride;
|
||||
index_t weight_c_stride;
|
||||
index_t weight_width_stride;
|
||||
index_t out_batch_stride;
|
||||
index_t out_c_stride;
|
||||
index_t out_l_stride;
|
||||
|
||||
index_t conv_state_batch_stride;
|
||||
index_t conv_state_c_stride;
|
||||
index_t conv_state_l_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ x_ptr;
|
||||
void *__restrict__ weight_ptr;
|
||||
void *__restrict__ bias_ptr;
|
||||
void *__restrict__ out_ptr;
|
||||
|
||||
void *__restrict__ conv_state_ptr;
|
||||
|
||||
void *__restrict__ seq_idx_ptr;
|
||||
|
||||
// No __restrict__ since initial_states could be the same as final_states.
|
||||
void * initial_states_ptr;
|
||||
index_t initial_states_batch_stride;
|
||||
index_t initial_states_l_stride;
|
||||
index_t initial_states_c_stride;
|
||||
|
||||
void * final_states_ptr;
|
||||
index_t final_states_batch_stride;
|
||||
index_t final_states_l_stride;
|
||||
index_t final_states_c_stride;
|
||||
};
|
||||
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor_sync(uint32_t(-1), val, offset);
|
||||
}
|
||||
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return std::max(ilist);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return std::min(a, b);
|
||||
}
|
||||
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor(val, offset);
|
||||
}
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return *std::max_element(ilist.begin(), ilist.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int BYTES> struct BytesToType {};
|
||||
|
||||
template<> struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct SumOp {
|
||||
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
||||
};
|
||||
|
||||
template<int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Allreduce<2> {
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
|
28
csrc/mamba/causal_conv1d/static_switch.h
Normal file
28
csrc/mamba/causal_conv1d/static_switch.h
Normal file
@ -0,0 +1,28 @@
|
||||
// Inspired by
|
||||
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h
|
||||
|
||||
#pragma once
|
||||
|
||||
/// @param COND - a boolean expression to switch by
|
||||
/// @param CONST_NAME - a name given for the constexpr bool variable.
|
||||
/// @param ... - code to execute for true and false
|
||||
///
|
||||
/// Usage:
|
||||
/// ```
|
||||
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
||||
/// some_function<BoolConst>(...);
|
||||
/// });
|
||||
/// ```
|
||||
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
static constexpr bool CONST_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
static constexpr bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
276
csrc/mamba/mamba_ssm/selective_scan.h
Normal file
276
csrc/mamba/mamba_ssm/selective_scan.h
Normal file
@ -0,0 +1,276 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
// clang-format off
|
||||
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#endif
|
||||
#include <cuda_fp16.h>
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SSMParamsBase {
|
||||
using index_t = uint32_t;
|
||||
|
||||
int batch, dim, seqlen, dstate, n_groups, n_chunks;
|
||||
int dim_ngroups_ratio;
|
||||
bool is_variable_B;
|
||||
bool is_variable_C;
|
||||
|
||||
bool delta_softplus;
|
||||
|
||||
index_t A_d_stride;
|
||||
index_t A_dstate_stride;
|
||||
index_t B_batch_stride;
|
||||
index_t B_d_stride;
|
||||
index_t B_dstate_stride;
|
||||
index_t B_group_stride;
|
||||
index_t C_batch_stride;
|
||||
index_t C_d_stride;
|
||||
index_t C_dstate_stride;
|
||||
index_t C_group_stride;
|
||||
index_t u_batch_stride;
|
||||
index_t u_d_stride;
|
||||
index_t delta_batch_stride;
|
||||
index_t delta_d_stride;
|
||||
index_t z_batch_stride;
|
||||
index_t z_d_stride;
|
||||
index_t out_batch_stride;
|
||||
index_t out_d_stride;
|
||||
index_t out_z_batch_stride;
|
||||
index_t out_z_d_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ A_ptr;
|
||||
void *__restrict__ B_ptr;
|
||||
void *__restrict__ C_ptr;
|
||||
void *__restrict__ D_ptr;
|
||||
void *__restrict__ u_ptr;
|
||||
void *__restrict__ delta_ptr;
|
||||
void *__restrict__ delta_bias_ptr;
|
||||
void *__restrict__ out_ptr;
|
||||
void *__restrict__ x_ptr;
|
||||
void *__restrict__ z_ptr;
|
||||
void *__restrict__ out_z_ptr;
|
||||
void *__restrict__ index_ptr;
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return std::max(ilist);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return std::min(a, b);
|
||||
}
|
||||
|
||||
#else
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return *std::max_element(ilist.begin(), ilist.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#define MAX_DSTATE 256
|
||||
|
||||
|
||||
inline __device__ float2 operator+(const float2 & a, const float2 & b){
|
||||
return {a.x + b.x, a.y + b.y};
|
||||
}
|
||||
|
||||
inline __device__ float3 operator+(const float3 &a, const float3 &b) {
|
||||
return {a.x + b.x, a.y + b.y, a.z + b.z};
|
||||
}
|
||||
|
||||
inline __device__ float4 operator+(const float4 & a, const float4 & b){
|
||||
return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int BYTES> struct BytesToType {};
|
||||
|
||||
template<> struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename scalar_t, int N>
|
||||
struct Converter{
|
||||
static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
|
||||
}
|
||||
};
|
||||
|
||||
template<int N>
|
||||
struct Converter<at::Half, N>{
|
||||
static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
|
||||
static_assert(N % 2 == 0);
|
||||
auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
|
||||
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
|
||||
}
|
||||
};
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
template<int N>
|
||||
struct Converter<at::BFloat16, N>{
|
||||
static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
|
||||
static_assert(N % 2 == 0);
|
||||
auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
|
||||
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
template<typename scalar_t> struct SSMScanOp;
|
||||
|
||||
template<>
|
||||
struct SSMScanOp<float> {
|
||||
__device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
|
||||
return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
|
||||
}
|
||||
};
|
||||
|
||||
// A stateful callback functor that maintains a running prefix to be applied
|
||||
// during consecutive scan operations.
|
||||
template <typename scalar_t> struct SSMScanPrefixCallbackOp {
|
||||
using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
|
||||
scan_t running_prefix;
|
||||
// Constructor
|
||||
__device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
|
||||
// Callback operator to be entered by the first warp of threads in the block.
|
||||
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
|
||||
__device__ scan_t operator()(scan_t block_aggregate) {
|
||||
scan_t old_prefix = running_prefix;
|
||||
running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
|
||||
return old_prefix;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void load_input(typename Ktraits::input_t *u,
|
||||
typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockLoadT::TempStorage &smem_load,
|
||||
int seqlen) {
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
|
||||
reinterpret_cast<vec_t*>(u),
|
||||
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
|
||||
#ifdef USE_ROCM
|
||||
, Ktraits::kNThreads * Ktraits::kNLoads
|
||||
#endif
|
||||
|
||||
);
|
||||
} else {
|
||||
typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void load_index(int *u,
|
||||
int (&u_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index,
|
||||
int seqlen) {
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_load_index_vec = reinterpret_cast<typename Ktraits::BlockLoadIndexVecT::TempStorage&>(smem_load_index);
|
||||
Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(
|
||||
reinterpret_cast<uint4*>(u),
|
||||
reinterpret_cast<uint4(&)[Ktraits::kNLoadsIndex]>(u_vals)
|
||||
);
|
||||
} else {
|
||||
Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
|
||||
typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
|
||||
int seqlen) {
|
||||
constexpr int kNItems = Ktraits::kNItems;
|
||||
typename Ktraits::input_t B_vals_load[kNItems];
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
|
||||
reinterpret_cast<vec_t*>(Bvar),
|
||||
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
|
||||
);
|
||||
} else {
|
||||
typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
|
||||
}
|
||||
// #pragma unroll
|
||||
// for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
|
||||
Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
|
||||
}
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void store_output(typename Ktraits::input_t *out,
|
||||
const float (&out_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockStoreT::TempStorage &smem_store,
|
||||
int seqlen) {
|
||||
typename Ktraits::input_t write_vals[Ktraits::kNItems];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
|
||||
reinterpret_cast<vec_t*>(out),
|
||||
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
|
||||
);
|
||||
} else {
|
||||
typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
|
||||
}
|
||||
}
|
593
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
Normal file
593
csrc/mamba/mamba_ssm/selective_scan_fwd.cu
Normal file
@ -0,0 +1,593 @@
|
||||
// clang-format off
|
||||
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "selective_scan.h"
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
#include <cub/block/block_scan.cuh>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
namespace cub = hipcub;
|
||||
#endif
|
||||
|
||||
#include "selective_scan.h"
|
||||
#include "static_switch.h"
|
||||
|
||||
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
|
||||
bool kIsVariableB_, bool kIsVariableC_,
|
||||
bool kHasZ_, bool kUseIndex_, typename input_t_, typename weight_t_>
|
||||
struct Selective_Scan_fwd_kernel_traits {
|
||||
static_assert(kNItems_ % 4 == 0);
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
|
||||
static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
|
||||
static constexpr int kNItems = kNItems_;
|
||||
static constexpr int kNRows = kNRows_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
|
||||
static_assert(kNItems % kNElts == 0);
|
||||
static constexpr int kNLoads = kNItems / kNElts;
|
||||
static constexpr bool kIsEvenLen = kIsEvenLen_;
|
||||
static constexpr bool kIsVariableB = kIsVariableB_;
|
||||
static constexpr bool kIsVariableC = kIsVariableC_;
|
||||
static constexpr bool kHasZ = kHasZ_;
|
||||
static constexpr bool kUseIndex = kUseIndex_;
|
||||
|
||||
static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
|
||||
static constexpr int kNLoadsIndex = kNItems / 4;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
using scan_t = float2;
|
||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
|
||||
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockLoadIndexT = cub::BlockLoad<int, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadIndexVecT = cub::BlockLoad<uint4, kNThreads, kNLoadsIndex,
|
||||
!(kIsEvenLen && kNLoadsIndex == 1) ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems , cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads ,
|
||||
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
|
||||
!kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
|
||||
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
|
||||
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
|
||||
using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
||||
static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
|
||||
sizeof(typename BlockLoadVecT::TempStorage),
|
||||
sizeof(typename BlockLoadIndexT::TempStorage),
|
||||
sizeof(typename BlockLoadIndexVecT::TempStorage),
|
||||
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
|
||||
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
|
||||
sizeof(typename BlockStoreT::TempStorage),
|
||||
sizeof(typename BlockStoreVecT::TempStorage)});
|
||||
static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
|
||||
void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
constexpr bool kIsVariableB = Ktraits::kIsVariableB;
|
||||
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
|
||||
constexpr bool kHasZ = Ktraits::kHasZ;
|
||||
constexpr bool kUseIndex = Ktraits::kUseIndex;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNItems = Ktraits::kNItems;
|
||||
constexpr int kNRows = Ktraits::kNRows;
|
||||
constexpr bool kDirectIO = Ktraits::kDirectIO;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
using scan_t = typename Ktraits::scan_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
// cast to lvalue reference of expected type
|
||||
// char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
|
||||
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
|
||||
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
|
||||
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
||||
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
|
||||
auto& smem_load_index = reinterpret_cast<typename Ktraits::BlockLoadIndexT::TempStorage&>(smem_);
|
||||
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
|
||||
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
||||
auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
||||
// weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
|
||||
// weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
|
||||
scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
|
||||
|
||||
const int batch_id = blockIdx.x;
|
||||
const int dim_id = blockIdx.y;
|
||||
const int group_id = dim_id / (params.dim_ngroups_ratio);
|
||||
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
|
||||
+ dim_id * kNRows * params.u_d_stride;
|
||||
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
|
||||
+ dim_id * kNRows * params.delta_d_stride;
|
||||
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
|
||||
weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
|
||||
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
|
||||
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
|
||||
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
|
||||
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
|
||||
int *index = !kUseIndex ? nullptr :reinterpret_cast<int *>(params.index_ptr) + batch_id * params.seqlen;
|
||||
|
||||
float D_val[kNRows] = {0};
|
||||
if (params.D_ptr != nullptr) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
|
||||
}
|
||||
}
|
||||
float delta_bias[kNRows] = {0};
|
||||
if (params.delta_bias_ptr != nullptr) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
|
||||
// smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
|
||||
// smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
|
||||
// }
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNItems;
|
||||
for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
|
||||
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
|
||||
int index_vals_load[kNRows][kNItems];
|
||||
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
if constexpr (!kDirectIO) {
|
||||
if (r > 0) { __syncthreads(); }
|
||||
}
|
||||
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
|
||||
if constexpr (!kDirectIO) { __syncthreads(); }
|
||||
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
|
||||
if constexpr (kUseIndex) {
|
||||
load_index<Ktraits>(index + r * params.delta_d_stride, index_vals_load[r], smem_load_index, params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
}
|
||||
if constexpr (kUseIndex) {
|
||||
index += kChunkSize;
|
||||
}
|
||||
u += kChunkSize;
|
||||
delta += kChunkSize;
|
||||
|
||||
float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
float u_val = float(u_vals[r][i]);
|
||||
delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
|
||||
if (params.delta_softplus) {
|
||||
delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
|
||||
}
|
||||
delta_u_vals[r][i] = delta_vals[r][i] * u_val;
|
||||
out_vals[r][i] = D_val[r] * u_val;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
|
||||
weight_t A_val[kNRows];
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
|
||||
// Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
|
||||
constexpr float kLog2e = M_LOG2E;
|
||||
A_val[r] *= kLog2e;
|
||||
}
|
||||
// This variable holds B * C if both B and C are constant across seqlen. If only B varies
|
||||
// across seqlen, this holds C. If only C varies across seqlen, this holds B.
|
||||
// If both B and C vary, this is unused.
|
||||
weight_t BC_val[kNRows];
|
||||
weight_t B_vals[kNItems], C_vals[kNItems];
|
||||
if constexpr (kIsVariableB) {
|
||||
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
||||
smem_load_weight, (params.seqlen - chunk * kChunkSize) * (1));
|
||||
if constexpr (!kIsVariableC) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr (kIsVariableC) {
|
||||
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
|
||||
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
||||
smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (1 ));
|
||||
if constexpr (!kIsVariableB) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr (!kIsVariableB && !kIsVariableC) {
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
if (r > 0) { __syncthreads(); } // Scan could be using the same smem
|
||||
scan_t thread_data[kNItems];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
|
||||
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
|
||||
|
||||
// Reset A bar for cumulative sequences (Real)
|
||||
if constexpr (kUseIndex) {
|
||||
if (index_vals_load[r][i] == 0) {
|
||||
thread_data[i].x = 0.f;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
|
||||
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
|
||||
thread_data[i] = make_float2(1.f, 0.f);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Initialize running total
|
||||
scan_t running_prefix;
|
||||
// If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
|
||||
running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : ( threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f));
|
||||
// running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
|
||||
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
||||
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
||||
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
||||
);
|
||||
// There's a syncthreads in the scan op, so we don't need to sync here.
|
||||
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
|
||||
if (threadIdx.x == 0) {
|
||||
smem_running_prefix[state_idx] = prefix_op.running_prefix;
|
||||
x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
const weight_t C_val = !kIsVariableC
|
||||
? BC_val[r]
|
||||
: (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
|
||||
out_vals[r][i] += thread_data[i].y * C_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
if constexpr (!kDirectIO) {
|
||||
if (r > 0) { __syncthreads(); }
|
||||
}
|
||||
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
|
||||
if constexpr (kHasZ) {
|
||||
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
|
||||
+ dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
|
||||
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
|
||||
+ dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
|
||||
#pragma unroll
|
||||
for (int r = 0; r < kNRows; ++r) {
|
||||
input_t z_vals[kNItems];
|
||||
__syncthreads();
|
||||
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNItems; ++i) {
|
||||
float z_val = z_vals[i];
|
||||
out_vals[r][i] *= z_val / (1 + expf(-z_val));
|
||||
}
|
||||
__syncthreads();
|
||||
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
}
|
||||
|
||||
Bvar += kChunkSize * 1;
|
||||
Cvar += kChunkSize * 1;
|
||||
}
|
||||
}
|
||||
|
||||
template<int kNThreads, int kNItems, typename input_t, typename weight_t>
|
||||
void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
|
||||
// Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
|
||||
// processing 1 row.
|
||||
constexpr int kNRows = 1;
|
||||
// kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
|
||||
constexpr bool kIsVariableB = true;
|
||||
constexpr bool kIsVariableC = true;
|
||||
constexpr bool kHasZ = true;
|
||||
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
||||
BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] {
|
||||
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
||||
dim3 grid(params.batch, params.dim / kNRows);
|
||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) {
|
||||
|
||||
#ifndef USE_ROCM
|
||||
if (params.seqlen <= 128) {
|
||||
selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 256) {
|
||||
selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 512) {
|
||||
selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 1024) {
|
||||
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
|
||||
} else {
|
||||
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
|
||||
}
|
||||
#else
|
||||
if (params.seqlen <= 256) {
|
||||
selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 512) {
|
||||
selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream);
|
||||
} else if (params.seqlen <= 1024) {
|
||||
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
|
||||
} else {
|
||||
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
template void selective_scan_fwd_cuda<float, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
||||
if (ITYPE == at::ScalarType::Half) { \
|
||||
using input_t = at::Half; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
||||
using input_t = at::BFloat16; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::Float) { \
|
||||
using input_t = float; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
||||
}
|
||||
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
// sizes
|
||||
const size_t batch,
|
||||
const size_t dim,
|
||||
const size_t seqlen,
|
||||
const size_t dstate,
|
||||
const size_t n_groups,
|
||||
const size_t n_chunks,
|
||||
const bool is_variable_B,
|
||||
const bool is_variable_C,
|
||||
// device pointers
|
||||
const torch::Tensor u,
|
||||
const torch::Tensor delta,
|
||||
const torch::Tensor A,
|
||||
const torch::Tensor B,
|
||||
const torch::Tensor C,
|
||||
const torch::Tensor out,
|
||||
const torch::Tensor z,
|
||||
const torch::Tensor out_z,
|
||||
void* D_ptr,
|
||||
void* delta_bias_ptr,
|
||||
void* x_ptr,
|
||||
bool has_z,
|
||||
bool delta_softplus,
|
||||
void* index_ptr) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
|
||||
params.batch = batch;
|
||||
params.dim = dim;
|
||||
params.seqlen = seqlen;
|
||||
params.dstate = dstate;
|
||||
params.n_groups = n_groups;
|
||||
params.n_chunks = n_chunks;
|
||||
params.dim_ngroups_ratio = dim / n_groups;
|
||||
|
||||
params.delta_softplus = delta_softplus;
|
||||
|
||||
params.is_variable_B = is_variable_B;
|
||||
params.is_variable_C = is_variable_C;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.u_ptr = u.data_ptr();
|
||||
params.delta_ptr = delta.data_ptr();
|
||||
params.A_ptr = A.data_ptr();
|
||||
params.B_ptr = B.data_ptr();
|
||||
params.C_ptr = C.data_ptr();
|
||||
params.D_ptr = D_ptr;
|
||||
params.delta_bias_ptr = delta_bias_ptr;
|
||||
params.out_ptr = out.data_ptr();
|
||||
params.x_ptr = x_ptr;
|
||||
params.z_ptr = has_z ? z.data_ptr() : nullptr;
|
||||
params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
|
||||
|
||||
params.index_ptr = index_ptr;
|
||||
|
||||
// All stride are in elements, not bytes.
|
||||
params.A_d_stride = A.stride(0);
|
||||
params.A_dstate_stride = A.stride(1);
|
||||
if (!is_variable_B) {
|
||||
params.B_d_stride = B.stride(0);
|
||||
} else {
|
||||
params.B_batch_stride = B.stride(0);
|
||||
params.B_group_stride = B.stride(1);
|
||||
}
|
||||
params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
|
||||
if (!is_variable_C) {
|
||||
params.C_d_stride = C.stride(0);
|
||||
} else {
|
||||
params.C_batch_stride = C.stride(0);
|
||||
params.C_group_stride = C.stride(1);
|
||||
}
|
||||
params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
|
||||
params.u_batch_stride = u.stride(0);
|
||||
params.u_d_stride = u.stride(1);
|
||||
params.delta_batch_stride = delta.stride(0);
|
||||
params.delta_d_stride = delta.stride(1);
|
||||
if (has_z) {
|
||||
params.z_batch_stride = z.stride(0);
|
||||
params.z_d_stride = z.stride(1);
|
||||
params.out_z_batch_stride = out_z.stride(0);
|
||||
params.out_z_d_stride = out_z.stride(1);
|
||||
}
|
||||
params.out_batch_stride = out.stride(0);
|
||||
params.out_d_stride = out.stride(1);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor>
|
||||
selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
|
||||
const c10::optional<torch::Tensor> &D_,
|
||||
const c10::optional<torch::Tensor> &z_,
|
||||
const c10::optional<torch::Tensor> &delta_bias_,
|
||||
bool delta_softplus,
|
||||
const c10::optional<torch::Tensor> &index_,
|
||||
const c10::optional<torch::Tensor> &x) {
|
||||
auto input_type = u.scalar_type();
|
||||
auto weight_type = A.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float);
|
||||
|
||||
const bool is_variable_B = B.dim() >= 3;
|
||||
const bool is_variable_C = C.dim() >= 3;
|
||||
|
||||
TORCH_CHECK(delta.scalar_type() == input_type);
|
||||
TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
|
||||
TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
|
||||
|
||||
TORCH_CHECK(u.is_cuda());
|
||||
TORCH_CHECK(delta.is_cuda());
|
||||
TORCH_CHECK(A.is_cuda());
|
||||
TORCH_CHECK(B.is_cuda());
|
||||
TORCH_CHECK(C.is_cuda());
|
||||
|
||||
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
||||
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
||||
|
||||
const auto sizes = u.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int seqlen = sizes[2];
|
||||
const int dstate = A.size(1);
|
||||
const int n_groups = is_variable_B ? B.size(1) : 1;
|
||||
|
||||
TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
|
||||
|
||||
CHECK_SHAPE(u, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(A, dim, dstate);
|
||||
TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size")
|
||||
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen );
|
||||
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
||||
|
||||
TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size")
|
||||
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
|
||||
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
||||
|
||||
if (D_.has_value()) {
|
||||
auto D = D_.value();
|
||||
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
||||
TORCH_CHECK(D.is_cuda());
|
||||
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
||||
CHECK_SHAPE(D, dim);
|
||||
}
|
||||
|
||||
if (delta_bias_.has_value()) {
|
||||
auto delta_bias = delta_bias_.value();
|
||||
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
||||
TORCH_CHECK(delta_bias.is_cuda());
|
||||
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
||||
CHECK_SHAPE(delta_bias, dim);
|
||||
}
|
||||
if (index_.has_value()) {
|
||||
auto index = index_.value();
|
||||
TORCH_CHECK(index.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(index.is_cuda());
|
||||
CHECK_SHAPE(index, batch_size, seqlen);
|
||||
}
|
||||
|
||||
at::Tensor z, out_z;
|
||||
const bool has_z = z_.has_value();
|
||||
TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
|
||||
z = z_.value();
|
||||
TORCH_CHECK(z.scalar_type() == input_type);
|
||||
TORCH_CHECK(z.is_cuda());
|
||||
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
||||
out_z = torch::empty_like(z);
|
||||
|
||||
const int n_chunks = (seqlen + 2048 - 1) / 2048;
|
||||
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
|
||||
// at::Tensor out = torch::empty_like(u);
|
||||
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
||||
at::Tensor out = torch::empty_like(delta);
|
||||
if (x.has_value()){
|
||||
auto _x = x.value();
|
||||
TORCH_CHECK(_x.scalar_type() == weight_type);
|
||||
TORCH_CHECK(_x.is_cuda());
|
||||
TORCH_CHECK(_x.stride(-1) == 1);
|
||||
CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2);
|
||||
}
|
||||
|
||||
SSMParamsBase params;
|
||||
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
|
||||
u, delta, A, B, C, out, z, out_z,
|
||||
D_.has_value() ? D_.value().data_ptr() : nullptr,
|
||||
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
|
||||
x.value().data_ptr(),
|
||||
has_z,
|
||||
delta_softplus,
|
||||
index_.has_value() ? index_.value().data_ptr() : nullptr);
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
|
||||
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
std::vector<at::Tensor> result = {out, x.value()};
|
||||
if (has_z) { result.push_back(out_z); }
|
||||
return result;
|
||||
}
|
||||
|
28
csrc/mamba/mamba_ssm/static_switch.h
Normal file
28
csrc/mamba/mamba_ssm/static_switch.h
Normal file
@ -0,0 +1,28 @@
|
||||
// Inspired by
|
||||
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
||||
|
||||
// clang-format off
|
||||
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/static_switch.h
|
||||
#pragma once
|
||||
|
||||
/// @param COND - a boolean expression to switch by
|
||||
/// @param CONST_NAME - a name given for the constexpr bool variable.
|
||||
/// @param ... - code to execute for true and false
|
||||
///
|
||||
/// Usage:
|
||||
/// ```
|
||||
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
||||
/// some_function<BoolConst>(...);
|
||||
/// });
|
||||
/// ```
|
||||
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
constexpr bool CONST_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
1740
csrc/moe/marlin_moe_ops.cu
Normal file
1740
csrc/moe/marlin_moe_ops.cu
Normal file
File diff suppressed because it is too large
Load Diff
12
csrc/moe/marlin_moe_ops.h
Normal file
12
csrc/moe/marlin_moe_ops.h
Normal file
@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
torch::Tensor marlin_gemm_moe(
|
||||
const torch::Tensor& a, const torch::Tensor& b_q_weights,
|
||||
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
|
||||
const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
|
||||
const torch::Tensor& g_idx, const torch::Tensor& perm,
|
||||
torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
|
||||
bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
|
||||
bool replicate_input, bool apply_weights);
|
@ -1,5 +1,6 @@
|
||||
#include "core/registration.h"
|
||||
#include "moe_ops.h"
|
||||
#include "marlin_moe_ops.h"
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
// Apply topk softmax to the gating outputs.
|
||||
@ -7,6 +8,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
|
||||
"token_expert_indices, Tensor gating_output) -> ()");
|
||||
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
m.def(
|
||||
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
|
||||
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
|
||||
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
|
||||
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
|
||||
"bool replicate_input, bool apply_weights) -> Tensor");
|
||||
|
||||
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
|
||||
#endif
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
22
csrc/ops.h
22
csrc/ops.h
@ -195,6 +195,28 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
|
||||
std::vector<torch::Tensor> selective_scan_fwd(
|
||||
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
|
||||
const torch::Tensor& B, const torch::Tensor& C,
|
||||
const c10::optional<torch::Tensor>& D_,
|
||||
const c10::optional<torch::Tensor>& z_,
|
||||
const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
|
||||
const c10::optional<torch::Tensor>& index_,
|
||||
const c10::optional<torch::Tensor>& x);
|
||||
|
||||
at::Tensor causal_conv1d_update(const at::Tensor& x,
|
||||
const at::Tensor& conv_state,
|
||||
const at::Tensor& weight,
|
||||
const c10::optional<at::Tensor>& bias_,
|
||||
bool silu_activation);
|
||||
|
||||
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
||||
const c10::optional<at::Tensor>& bias_,
|
||||
const c10::optional<at::Tensor>& seq_idx_,
|
||||
const c10::optional<at::Tensor>& initial_states_,
|
||||
const c10::optional<at::Tensor>& final_states_out_,
|
||||
bool silu_activation);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
using fptr_t = int64_t;
|
||||
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
||||
|
@ -202,6 +202,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
|
||||
ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
|
||||
&cutlass_scaled_mm_supports_fp8);
|
||||
// Mamba selective scan kernel
|
||||
ops.def(
|
||||
"selective_scan_fwd(Tensor! u, Tensor! delta,"
|
||||
"Tensor! A, Tensor! B, Tensor! C,"
|
||||
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
|
||||
"bool delta_softplus,"
|
||||
"Tensor? index_, Tensor? x) -> Tensor[]");
|
||||
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
||||
|
||||
ops.def(
|
||||
"causal_conv1d_update(Tensor! x,"
|
||||
"Tensor! conv_state,"
|
||||
"Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"bool silu_activation) -> Tensor");
|
||||
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
|
||||
|
||||
ops.def(
|
||||
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"Tensor? seq_idx_,"
|
||||
"Tensor? initial_states_,"
|
||||
"Tensor? final_states_out_,"
|
||||
"bool silu_activation) -> Tensor");
|
||||
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
|
||||
#endif
|
||||
|
||||
// Quantized GEMM for GPTQ.
|
||||
|
@ -12,3 +12,5 @@ torch
|
||||
py-cpuinfo
|
||||
transformers
|
||||
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
|
||||
mistral_common >= 1.3.4
|
||||
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
|
@ -45,8 +45,6 @@ Base Classes
|
||||
|
||||
.. autodata:: vllm.multimodal.NestedTensors
|
||||
|
||||
.. autodata:: vllm.multimodal.BatchedTensors
|
||||
|
||||
.. autodata:: vllm.multimodal.BatchedTensorInputs
|
||||
|
||||
.. autoclass:: vllm.multimodal.MultiModalDataBuiltins
|
||||
|
@ -56,9 +56,10 @@ First, install the dependencies:
|
||||
$ pip uninstall torch torch-xla -y
|
||||
|
||||
$ # Install PyTorch and PyTorch XLA.
|
||||
$ export DATE="+20240808"
|
||||
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
|
||||
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
|
||||
$ export DATE="20240828"
|
||||
$ export TORCH_VERSION="2.5.0"
|
||||
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
|
||||
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
|
||||
|
||||
$ # Install JAX and Pallas.
|
||||
$ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||
|
@ -51,6 +51,10 @@ Decoder-only Language Models
|
||||
- DeciLM
|
||||
- :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc.
|
||||
-
|
||||
* - :code:`ExaoneForCausalLM`
|
||||
- EXAONE-3
|
||||
- :code:`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc.
|
||||
- ✅︎
|
||||
* - :code:`FalconForCausalLM`
|
||||
- Falcon
|
||||
- :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
|
||||
@ -143,6 +147,10 @@ Decoder-only Language Models
|
||||
- Phi-3-Small
|
||||
- :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
|
||||
-
|
||||
* - :code:`PhiMoEForCausalLM`
|
||||
- Phi-3.5-MoE
|
||||
- :code:`microsoft/Phi-3.5-MoE-instruct`, etc.
|
||||
-
|
||||
* - :code:`PersimmonForCausalLM`
|
||||
- Persimmon
|
||||
- :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc.
|
||||
|
@ -20,4 +20,4 @@ The performance benchmarks and nightly benchmarks can be triggered by submitting
|
||||
|
||||
.. note::
|
||||
|
||||
Please refer to `vLLM performance benchmark descriptions <https://github.com/vllm-project/vllm/blob/main/.buildkite/nightly-benchmarks/tests/descriptions.md>`_ and `vLLM nightly benchmark descriptions <https://github.com/vllm-project/vllm/blob/main/.buildkite/nightly-benchmarks/nightly-descriptions.md>`_ for detailed descriptions on benchmark environment, workload and metrics.
|
||||
Please refer to `vLLM performance benchmark descriptions <https://github.com/vllm-project/vllm/blob/main/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md>`_ and `vLLM nightly benchmark descriptions <https://github.com/vllm-project/vllm/blob/main/.buildkite/nightly-benchmarks/nightly-descriptions.md>`_ for detailed descriptions on benchmark environment, workload and metrics.
|
||||
|
@ -19,27 +19,31 @@ You can quantize your own models by installing AutoAWQ or picking one of the `40
|
||||
|
||||
$ pip install autoawq
|
||||
|
||||
After installing AutoAWQ, you are ready to quantize a model. Here is an example of how to quantize Vicuna 7B v1.5:
|
||||
After installing AutoAWQ, you are ready to quantize a model. Here is an example of how to quantize `mistralai/Mistral-7B-Instruct-v0.2`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from awq import AutoAWQForCausalLM
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
model_path = 'lmsys/vicuna-7b-v1.5'
|
||||
quant_path = 'vicuna-7b-v1.5-awq'
|
||||
|
||||
model_path = 'mistralai/Mistral-7B-Instruct-v0.2'
|
||||
quant_path = 'mistral-instruct-v0.2-awq'
|
||||
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
|
||||
|
||||
|
||||
# Load model
|
||||
model = AutoAWQForCausalLM.from_pretrained(model_path, **{"low_cpu_mem_usage": True})
|
||||
model = AutoAWQForCausalLM.from_pretrained(
|
||||
model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
|
||||
# Quantize
|
||||
model.quantize(tokenizer, quant_config=quant_config)
|
||||
|
||||
|
||||
# Save quantized model
|
||||
model.save_quantized(quant_path)
|
||||
tokenizer.save_pretrained(quant_path)
|
||||
|
||||
print(f'Model is quantized and saved at "{quant_path}"')
|
||||
|
||||
To run an AWQ model with vLLM, you can use `TheBloke/Llama-2-7b-Chat-AWQ <https://huggingface.co/TheBloke/Llama-2-7b-Chat-AWQ>`_ with the following command:
|
||||
|
||||
|
@ -110,14 +110,90 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
|
||||
:func: create_parser_for_docs
|
||||
:prog: vllm serve
|
||||
```
|
||||
## Tool Calling in the Chat Completion API
|
||||
### Named Function Calling
|
||||
vLLM supports only named function calling in the chat completion API by default. It does so using Outlines, so this is
|
||||
enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a
|
||||
high-quality one.
|
||||
|
||||
To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and
|
||||
specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request.
|
||||
|
||||
### Config file
|
||||
|
||||
The `serve` module can also accept arguments from a config file in
|
||||
`yaml` format. The arguments in the yaml must be specified using the
|
||||
long form of the argument outlined [here](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server):
|
||||
|
||||
For example:
|
||||
|
||||
```yaml
|
||||
# config.yaml
|
||||
|
||||
host: "127.0.0.1"
|
||||
port: 6379
|
||||
uvicorn-log-level: "info"
|
||||
```
|
||||
|
||||
```bash
|
||||
$ vllm serve SOME_MODEL --config config.yaml
|
||||
```
|
||||
---
|
||||
**NOTE**
|
||||
In case an argument is supplied using command line and the config file, the value from the commandline will take precedence.
|
||||
The order of priorities is `command line > config file values > defaults`.
|
||||
|
||||
---
|
||||
|
||||
## Tool calling in the chat completion API
|
||||
vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but on the roadmap.
|
||||
|
||||
To use a named function you need to define the function in the `tools` parameter and call it in the `tool_choice` parameter.
|
||||
|
||||
It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt. **This may change in the future.**
|
||||
It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt.
|
||||
|
||||
vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.
|
||||
|
||||
Please refer to the OpenAI API reference documentation for more information.
|
||||
|
||||
### Automatic Function Calling
|
||||
To enable this feature, you should set the following flags:
|
||||
* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it
|
||||
deems appropriate.
|
||||
* `--tool-call-parser` -- select the tool parser to use - currently either `hermes` or `mistral`. Additional tool parsers
|
||||
will continue to be added in the future.
|
||||
* `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages
|
||||
that contain previously generated tool calls. Hermes and Mistral models have tool-compatible chat templates in their
|
||||
`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat
|
||||
template configured in the `tokenizer_config.json`. In this case, it will be used per the `transformers` specification. More on this [here](https://huggingface.co/docs/transformers/en/chat_templating#why-do-some-models-have-multiple-templates)
|
||||
from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json)
|
||||
|
||||
If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template!
|
||||
|
||||
#### Hermes Models
|
||||
All Nous Research Hermes-series models newer than Hermes 2 Pro should be supported.
|
||||
* `NousResearch/Hermes-2-Pro-*`
|
||||
* `NousResearch/Hermes-2-Theta-*`
|
||||
* `NousResearch/Hermes-3-*`
|
||||
|
||||
|
||||
_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge
|
||||
step in their creation_.
|
||||
|
||||
Flags: `--tool-call-parser hermes`
|
||||
|
||||
#### Mistral Models
|
||||
Supported models:
|
||||
* `mistralai/Mistral-7B-Instruct-v0.3` (confirmed)
|
||||
* Additional mistral function-calling models are compatible as well.
|
||||
|
||||
Known issues:
|
||||
1. Mistral 7B struggles to generate parallel tool calls correctly.
|
||||
2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is
|
||||
much shorter than what vLLM generates. Since an exception is thrown when this condition
|
||||
is not met, the following additional chat templates are provided:
|
||||
|
||||
* `examples/tool_chat_template_mistral.jinja` - this is the "official" Mistral chat template, but tweaked so that
|
||||
it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits)
|
||||
* `examples/tool_chat_template_mistral_parallel.jinja` - this is a "better" version that adds a tool-use system prompt
|
||||
when tools are provided, that results in much better reliability when working with parallel tool calling.
|
||||
|
||||
|
||||
Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
|
||||
|
@ -11,25 +11,33 @@ from vllm import LLM, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
# Input audio and question
|
||||
audio_and_sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate
|
||||
question = "What is recited in the audio?"
|
||||
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
|
||||
question_per_audio_count = [
|
||||
"What is recited in the audio?",
|
||||
"What sport and what nursery rhyme are referenced?"
|
||||
]
|
||||
|
||||
|
||||
# Ultravox 0.3
|
||||
def run_ultravox(question):
|
||||
def run_ultravox(question, audio_count):
|
||||
model_name = "fixie-ai/ultravox-v0_3"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
messages = [{
|
||||
'role': 'user',
|
||||
'content': f"<|reserved_special_token_0|>\n{question}"
|
||||
'role':
|
||||
'user',
|
||||
'content':
|
||||
"<|reserved_special_token_0|>\n" * audio_count + question
|
||||
}]
|
||||
prompt = tokenizer.apply_chat_template(messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
|
||||
llm = LLM(model=model_name)
|
||||
llm = LLM(model=model_name,
|
||||
enforce_eager=True,
|
||||
enable_chunked_prefill=False,
|
||||
max_model_len=8192,
|
||||
limit_mm_per_prompt={"audio": audio_count})
|
||||
stop_token_ids = None
|
||||
return llm, prompt, stop_token_ids
|
||||
|
||||
@ -44,7 +52,9 @@ def main(args):
|
||||
if model not in model_example_map:
|
||||
raise ValueError(f"Model type {model} is not supported.")
|
||||
|
||||
llm, prompt, stop_token_ids = model_example_map[model](question)
|
||||
audio_count = args.num_audios
|
||||
llm, prompt, stop_token_ids = model_example_map[model](
|
||||
question_per_audio_count[audio_count - 1], audio_count)
|
||||
|
||||
# We set temperature to 0.2 so that outputs can be different
|
||||
# even when all prompts are identical when running batch inference.
|
||||
@ -53,23 +63,18 @@ def main(args):
|
||||
stop_token_ids=stop_token_ids)
|
||||
|
||||
assert args.num_prompts > 0
|
||||
if args.num_prompts == 1:
|
||||
# Single inference
|
||||
inputs = {
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"audio": audio_and_sample_rate
|
||||
},
|
||||
}
|
||||
|
||||
else:
|
||||
inputs = {
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"audio": [
|
||||
asset.audio_and_sample_rate
|
||||
for asset in audio_assets[:audio_count]
|
||||
]
|
||||
},
|
||||
}
|
||||
if args.num_prompts > 1:
|
||||
# Batch inference
|
||||
inputs = [{
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"audio": audio_and_sample_rate
|
||||
},
|
||||
} for _ in range(args.num_prompts)]
|
||||
inputs = [inputs] * args.num_prompts
|
||||
|
||||
outputs = llm.generate(inputs, sampling_params=sampling_params)
|
||||
|
||||
@ -92,6 +97,11 @@ if __name__ == "__main__":
|
||||
type=int,
|
||||
default=1,
|
||||
help='Number of prompts to run.')
|
||||
parser.add_argument("--num-audios",
|
||||
type=int,
|
||||
default=1,
|
||||
choices=[1, 2],
|
||||
help="Number of audio items per prompt.")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -1,5 +1,12 @@
|
||||
import os
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# creates XLA hlo graphs for all the context length buckets.
|
||||
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
|
||||
# creates XLA hlo graphs for all the token gen buckets.
|
||||
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
@ -19,8 +26,8 @@ llm = LLM(
|
||||
# Currently, this is a known limitation in continuous batching support
|
||||
# in transformers-neuronx.
|
||||
# TODO(liangfu): Support paged-attention in transformers-neuronx.
|
||||
max_model_len=128,
|
||||
block_size=128,
|
||||
max_model_len=2048,
|
||||
block_size=2048,
|
||||
# The device can be automatically detected when AWS Neuron SDK is installed.
|
||||
# The device argument can be either unspecified for automated detection,
|
||||
# or explicitly assigned.
|
||||
|
50
examples/offline_inference_neuron_int8_quantization.py
Normal file
50
examples/offline_inference_neuron_int8_quantization.py
Normal file
@ -0,0 +1,50 @@
|
||||
import os
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# creates XLA hlo graphs for all the context length buckets.
|
||||
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
|
||||
# creates XLA hlo graphs for all the token gen buckets.
|
||||
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
|
||||
# Quantizes neuron model weight to int8 ,
|
||||
# The default config for quantization is int8 dtype.
|
||||
os.environ['NEURON_QUANT_DTYPE'] = "s8"
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(
|
||||
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
max_num_seqs=8,
|
||||
# The max_model_len and block_size arguments are required to be same as
|
||||
# max sequence length when targeting neuron device.
|
||||
# Currently, this is a known limitation in continuous batching support
|
||||
# in transformers-neuronx.
|
||||
# TODO(liangfu): Support paged-attention in transformers-neuronx.
|
||||
max_model_len=2048,
|
||||
block_size=2048,
|
||||
# The device can be automatically detected when AWS Neuron SDK is installed.
|
||||
# The device argument can be either unspecified for automated detection,
|
||||
# or explicitly assigned.
|
||||
device="neuron",
|
||||
quantization="neuron_quant",
|
||||
override_neuron_config={
|
||||
"cast_logits_dtype": "bfloat16",
|
||||
},
|
||||
tensor_parallel_size=2)
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
162
examples/openai_chat_completion_client_with_tools.py
Normal file
162
examples/openai_chat_completion_client_with_tools.py
Normal file
@ -0,0 +1,162 @@
|
||||
"""
|
||||
Set up this example by starting a vLLM OpenAI-compatible server with tool call
|
||||
options enabled. For example:
|
||||
|
||||
IMPORTANT: for mistral, you must use one of the provided mistral tool call
|
||||
templates, or your own - the model default doesn't work for tool calls with vLLM
|
||||
See the vLLM docs on OpenAI server & tool calling for more details.
|
||||
|
||||
vllm serve --model mistralai/Mistral-7B-Instruct-v0.3 \
|
||||
--chat-template examples/tool_chat_template_mistral.jinja \
|
||||
--enable-auto-tool-choice --tool-call-parser mistral
|
||||
|
||||
OR
|
||||
vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \
|
||||
--chat-template examples/tool_chat_template_hermes.jinja \
|
||||
--enable-auto-tool-choice --tool-call-parser hermes
|
||||
"""
|
||||
import json
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
client = OpenAI(
|
||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The city to find the weather for, e.g. 'San Francisco'"
|
||||
},
|
||||
"state": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"the two-letter abbreviation for the state that the city is"
|
||||
" in, e.g. 'CA' which would mean 'California'"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
},
|
||||
"required": ["city", "state", "unit"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": "Hi! How are you doing today?"
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well! How can I help you?"
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
|
||||
}]
|
||||
|
||||
chat_completion = client.chat.completions.create(messages=messages,
|
||||
model=model,
|
||||
tools=tools)
|
||||
|
||||
print("Chat completion results:")
|
||||
print(chat_completion)
|
||||
print("\n\n")
|
||||
|
||||
tool_calls_stream = client.chat.completions.create(messages=messages,
|
||||
model=model,
|
||||
tools=tools,
|
||||
stream=True)
|
||||
|
||||
chunks = []
|
||||
for chunk in tool_calls_stream:
|
||||
chunks.append(chunk)
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
print(chunk.choices[0].delta.tool_calls[0])
|
||||
else:
|
||||
print(chunk.choices[0].delta)
|
||||
|
||||
arguments = []
|
||||
tool_call_idx = -1
|
||||
for chunk in chunks:
|
||||
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
tool_call = chunk.choices[0].delta.tool_calls[0]
|
||||
|
||||
if tool_call.index != tool_call_idx:
|
||||
if tool_call_idx >= 0:
|
||||
print(
|
||||
f"streamed tool call arguments: {arguments[tool_call_idx]}"
|
||||
)
|
||||
tool_call_idx = chunk.choices[0].delta.tool_calls[0].index
|
||||
arguments.append("")
|
||||
if tool_call.id:
|
||||
print(f"streamed tool call id: {tool_call.id} ")
|
||||
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
print(f"streamed tool call name: {tool_call.function.name}")
|
||||
|
||||
if tool_call.function.arguments:
|
||||
arguments[tool_call_idx] += tool_call.function.arguments
|
||||
|
||||
if len(arguments):
|
||||
print(f"streamed tool call arguments: {arguments[-1]}")
|
||||
|
||||
print("\n\n")
|
||||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"tool_calls": chat_completion.choices[0].message.tool_calls
|
||||
})
|
||||
|
||||
|
||||
# Now, simulate a tool call
|
||||
def get_current_weather(city: str, state: str, unit: 'str'):
|
||||
return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
|
||||
"partly cloudly, with highs in the 90's.")
|
||||
|
||||
|
||||
available_tools = {"get_current_weather": get_current_weather}
|
||||
|
||||
completion_tool_calls = chat_completion.choices[0].message.tool_calls
|
||||
for call in completion_tool_calls:
|
||||
tool_to_call = available_tools[call.function.name]
|
||||
args = json.loads(call.function.arguments)
|
||||
result = tool_to_call(**args)
|
||||
print(result)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"content": result,
|
||||
"tool_call_id": call.id,
|
||||
"name": call.function.name
|
||||
})
|
||||
|
||||
chat_completion_2 = client.chat.completions.create(messages=messages,
|
||||
model=model,
|
||||
tools=tools,
|
||||
stream=False)
|
||||
print("\n\n")
|
||||
print(chat_completion_2)
|
@ -19,7 +19,6 @@ responses = client.embeddings.create(
|
||||
"The best thing about vLLM is that it supports many different models"
|
||||
],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
for data in responses.data:
|
||||
|
@ -1,7 +1,13 @@
|
||||
"""An example showing how to use vLLM to serve VLMs.
|
||||
|
||||
Launch the vLLM server with the following command:
|
||||
|
||||
(single image inference with Llava)
|
||||
vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja
|
||||
|
||||
(multi-image inference with Phi-3.5-vision-instruct)
|
||||
vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \
|
||||
--trust-remote-code --limit-mm-per-prompt image=2
|
||||
"""
|
||||
import base64
|
||||
|
||||
@ -84,3 +90,36 @@ chat_completion_from_base64 = client.chat.completions.create(
|
||||
|
||||
result = chat_completion_from_base64.choices[0].message.content
|
||||
print(f"Chat completion output:{result}")
|
||||
|
||||
# Multi-image input inference
|
||||
image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
|
||||
image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"
|
||||
chat_completion_from_url = client.chat.completions.create(
|
||||
messages=[{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What are the animals in these images?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url_duck
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url_lion
|
||||
},
|
||||
},
|
||||
],
|
||||
}],
|
||||
model=model,
|
||||
max_tokens=64,
|
||||
)
|
||||
|
||||
result = chat_completion_from_url.choices[0].message.content
|
||||
print(f"Chat completion output:{result}")
|
||||
|
129
examples/tool_chat_template_hermes.jinja
Normal file
129
examples/tool_chat_template_hermes.jinja
Normal file
@ -0,0 +1,129 @@
|
||||
{%- macro json_to_python_type(json_spec) %}
|
||||
{%- set basic_type_map = {
|
||||
"string": "str",
|
||||
"number": "float",
|
||||
"integer": "int",
|
||||
"boolean": "bool"
|
||||
} %}
|
||||
|
||||
{%- if basic_type_map[json_spec.type] is defined %}
|
||||
{{- basic_type_map[json_spec.type] }}
|
||||
{%- elif json_spec.type == "array" %}
|
||||
{{- "list[" + json_to_python_type(json_spec|items) + "]" }}
|
||||
{%- elif json_spec.type == "object" %}
|
||||
{%- if json_spec.additionalProperties is defined %}
|
||||
{{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']' }}
|
||||
{%- else %}
|
||||
{{- "dict" }}
|
||||
{%- endif %}
|
||||
{%- elif json_spec.type is iterable %}
|
||||
{{- "Union[" }}
|
||||
{%- for t in json_spec.type %}
|
||||
{{- json_to_python_type({"type": t}) }}
|
||||
{%- if not loop.last %}
|
||||
{{- "," }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- "]" }}
|
||||
{%- else %}
|
||||
{{- "Any" }}
|
||||
{%- endif %}
|
||||
{%- endmacro %}
|
||||
|
||||
|
||||
{{- bos_token }}
|
||||
{{- "<|im_start|>system\nYou are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> " }}
|
||||
{%- if tools is iterable and tools | length > 0 %}
|
||||
{%- for tool in tools %}
|
||||
{%- if tool.function is defined %}
|
||||
{%- set tool = tool.function %}
|
||||
{%- endif %}
|
||||
{{- '{"type": "function", "function": ' }}
|
||||
{{- '{"name": "' + tool.name + '", ' }}
|
||||
{{- '"description": "' + tool.name + '(' }}
|
||||
{%- for param_name, param_fields in tool.parameters.properties|items %}
|
||||
{{- param_name + ": " + json_to_python_type(param_fields) }}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- ")" }}
|
||||
{%- if tool.return is defined %}
|
||||
{{- " -> " + json_to_python_type(tool.return) }}
|
||||
{%- endif %}
|
||||
{{- " - " + tool.description + "\n\n" }}
|
||||
{%- for param_name, param_fields in tool.parameters.properties|items %}
|
||||
{%- if loop.first %}
|
||||
{{- " Args:\n" }}
|
||||
{%- endif %}
|
||||
{{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }}
|
||||
{%- endfor %}
|
||||
{%- if tool.return is defined and tool.return.description is defined %}
|
||||
{{- "\n Returns:\n " + tool.return.description }}
|
||||
{%- endif %}
|
||||
{{- '"' }}
|
||||
{{- ', "parameters": ' }}
|
||||
{%- if tool.parameters.properties | length == 0 %}
|
||||
{{- "{}" }}
|
||||
{%- else %}
|
||||
{{- tool.parameters|tojson }}
|
||||
{%- endif %}
|
||||
{{- "}" }}
|
||||
{%- if not loop.last %}
|
||||
{{- "\n" }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- " </tools>" }}
|
||||
{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}}
|
||||
' }}
|
||||
{{- "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
||||
" }}
|
||||
{{- "<tool_call>
|
||||
" }}
|
||||
{{- '{"name": <function-name>, "arguments": <args-dict>}
|
||||
' }}
|
||||
{{- '</tool_call><|im_end|>' }}
|
||||
{%- for message in messages %}
|
||||
{%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" and message.tool_calls is defined %}
|
||||
{{- '<|im_start|>' + message.role }}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{{- '\n<tool_call>\n' }}
|
||||
{%- if tool_call.function is defined %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '{' }}
|
||||
{{- '"name": "' }}
|
||||
{{- tool_call.name }}
|
||||
{{- '"}' }}
|
||||
{{- ', ' }}
|
||||
{%- if tool_call.arguments is defined %}
|
||||
{{- '"arguments": ' }}
|
||||
{{- tool_call.arguments|tojson }}
|
||||
{%- endif %}
|
||||
{{- '\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if loop.previtem and loop.previtem.role != "tool" %}
|
||||
{{- '<|im_start|>tool\n' }}
|
||||
{%- endif %}
|
||||
{{- '<tool_response>\n' }}
|
||||
{{- message.content }}
|
||||
{%- if not loop.last %}
|
||||
{{- '\n</tool_response>\n' }}
|
||||
{%- else %}
|
||||
{{- '\n</tool_response>' }}
|
||||
{%- endif %}
|
||||
{%- if not loop.last and loop.nextitem.role != "tool" %}
|
||||
{{- '<|im_end|>' }}
|
||||
{%- elif loop.last %}
|
||||
{{- '<|im_end|>' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- endif %}
|
86
examples/tool_chat_template_mistral.jinja
Normal file
86
examples/tool_chat_template_mistral.jinja
Normal file
@ -0,0 +1,86 @@
|
||||
{%- if messages[0]["role"] == "system" %}
|
||||
{%- set system_message = messages[0]["content"] %}
|
||||
{%- set loop_messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- set loop_messages = messages %}
|
||||
{%- endif %}
|
||||
{%- if not tools is defined %}
|
||||
{%- set tools = none %}
|
||||
{%- endif %}
|
||||
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
|
||||
|
||||
{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %}
|
||||
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
|
||||
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{{- bos_token }}
|
||||
{%- for message in loop_messages %}
|
||||
{%- if message["role"] == "user" %}
|
||||
{%- if tools is not none and (message == user_messages[-1]) %}
|
||||
{{- "[AVAILABLE_TOOLS] [" }}
|
||||
{%- for tool in tools %}
|
||||
{%- set tool = tool.function %}
|
||||
{{- '{"type": "function", "function": {' }}
|
||||
{%- for key, val in tool.items() if key != "return" %}
|
||||
{%- if val is string %}
|
||||
{{- '"' + key + '": "' + val + '"' }}
|
||||
{%- else %}
|
||||
{{- '"' + key + '": ' + val|tojson }}
|
||||
{%- endif %}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- "}}" }}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- else %}
|
||||
{{- "]" }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- "[/AVAILABLE_TOOLS]" }}
|
||||
{%- endif %}
|
||||
{%- if loop.last and system_message is defined %}
|
||||
{{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }}
|
||||
{%- else %}
|
||||
{{- "[INST] " + message["content"] + "[/INST]" }}
|
||||
{%- endif %}
|
||||
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
|
||||
{%- if message.tool_calls is defined %}
|
||||
{%- set tool_calls = message.tool_calls %}
|
||||
{%- else %}
|
||||
{%- set tool_calls = message.content %}
|
||||
{%- endif %}
|
||||
{{- "[TOOL_CALLS] [" }}
|
||||
{%- for tool_call in tool_calls %}
|
||||
{%- set out = tool_call.function|tojson %}
|
||||
{{- out[:-1] }}
|
||||
{%- if not tool_call.id is defined or tool_call.id|length < 9 %}
|
||||
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }}
|
||||
{%- endif %}
|
||||
{{- ', "id": "' + tool_call.id[-9:] + '"}' }}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- else %}
|
||||
{{- "]" + eos_token }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- elif message["role"] == "assistant" %}
|
||||
{{- " " + message["content"] + eos_token }}
|
||||
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
|
||||
{%- if message.content is defined and message.content.content is defined %}
|
||||
{%- set content = message.content.content %}
|
||||
{%- else %}
|
||||
{%- set content = message.content %}
|
||||
{%- endif %}
|
||||
{{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }}
|
||||
{%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %}
|
||||
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }}
|
||||
{%- endif %}
|
||||
{{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }}
|
||||
{%- else %}
|
||||
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
94
examples/tool_chat_template_mistral_parallel.jinja
Normal file
94
examples/tool_chat_template_mistral_parallel.jinja
Normal file
@ -0,0 +1,94 @@
|
||||
{%- if messages[0]["role"] == "system" %}
|
||||
{%- set system_message = messages[0]["content"] %}
|
||||
{%- set loop_messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- set loop_messages = messages %}
|
||||
{%- endif %}
|
||||
{%- if not tools is defined %}
|
||||
{%- set tools = none %}
|
||||
{%- endif %}
|
||||
{%- if tools is defined %}
|
||||
{%- set parallel_tool_prompt = "You are a helpful assistant that can call tools. If you call one or more tools, format them in a single JSON array or objects, where each object is a tool call, not as separate objects outside of an array or multiple arrays. Use the format [{\"name\": tool call name, \"arguments\": tool call arguments}, additional tool calls] if you call more than one tool. If you call tools, do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %}
|
||||
{%- if system_message is defined %}
|
||||
{%- set system_message = parallel_tool_prompt + "\n\n" + system_message %}
|
||||
{%- else %}
|
||||
{%- set system_message = parallel_tool_prompt %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
|
||||
|
||||
{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %}
|
||||
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
|
||||
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{{- bos_token }}
|
||||
{%- for message in loop_messages %}
|
||||
{%- if message["role"] == "user" %}
|
||||
{%- if tools is not none and (message == user_messages[-1]) %}
|
||||
{{- "[AVAILABLE_TOOLS] [" }}
|
||||
{%- for tool in tools %}
|
||||
{%- set tool = tool.function %}
|
||||
{{- '{"type": "function", "function": {' }}
|
||||
{%- for key, val in tool.items() if key != "return" %}
|
||||
{%- if val is string %}
|
||||
{{- '"' + key + '": "' + val + '"' }}
|
||||
{%- else %}
|
||||
{{- '"' + key + '": ' + val|tojson }}
|
||||
{%- endif %}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- "}}" }}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- else %}
|
||||
{{- "]" }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{- "[/AVAILABLE_TOOLS]" }}
|
||||
{%- endif %}
|
||||
{%- if loop.last and system_message is defined %}
|
||||
{{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }}
|
||||
{%- else %}
|
||||
{{- "[INST] " + message["content"] + "[/INST]" }}
|
||||
{%- endif %}
|
||||
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
|
||||
{%- if message.tool_calls is defined %}
|
||||
{%- set tool_calls = message.tool_calls %}
|
||||
{%- else %}
|
||||
{%- set tool_calls = message.content %}
|
||||
{%- endif %}
|
||||
{{- "[TOOL_CALLS] [" }}
|
||||
{%- for tool_call in tool_calls %}
|
||||
{%- set out = tool_call.function|tojson %}
|
||||
{{- out[:-1] }}
|
||||
{%- if not tool_call.id is defined or tool_call.id|length < 9 %}
|
||||
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }}
|
||||
{%- endif %}
|
||||
{{- ', "id": "' + tool_call.id[-9:] + '"}' }}
|
||||
{%- if not loop.last %}
|
||||
{{- ", " }}
|
||||
{%- else %}
|
||||
{{- "]" + eos_token }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- elif message["role"] == "assistant" %}
|
||||
{{- " " + message["content"] + eos_token }}
|
||||
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
|
||||
{%- if message.content is defined and message.content.content is defined %}
|
||||
{%- set content = message.content.content %}
|
||||
{%- else %}
|
||||
{%- set content = message.content %}
|
||||
{%- endif %}
|
||||
{{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }}
|
||||
{%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %}
|
||||
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }}
|
||||
{%- endif %}
|
||||
{{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }}
|
||||
{%- else %}
|
||||
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
@ -99,7 +99,6 @@ echo 'vLLM mypy:'
|
||||
mypy --follow-imports skip # Note that this is less strict than CI
|
||||
mypy tests --follow-imports skip
|
||||
mypy vllm/attention --follow-imports skip
|
||||
mypy vllm/core --follow-imports skip
|
||||
mypy vllm/distributed --follow-imports skip
|
||||
mypy vllm/engine --follow-imports skip
|
||||
mypy vllm/executor --follow-imports skip
|
||||
|
@ -58,6 +58,7 @@ files = [
|
||||
"vllm/adapter_commons",
|
||||
"vllm/assets",
|
||||
"vllm/entrypoints",
|
||||
"vllm/core",
|
||||
"vllm/inputs",
|
||||
"vllm/logging",
|
||||
"vllm/multimodal",
|
||||
|
@ -20,9 +20,10 @@ lm-format-enforcer == 0.10.6
|
||||
outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
|
||||
typing_extensions >= 4.10
|
||||
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
|
||||
partial-json-parser # used for parsing partial JSON outputs
|
||||
pyzmq
|
||||
msgspec
|
||||
librosa # Required for audio processing
|
||||
soundfile # Required for audio processing
|
||||
gguf == 0.9.1
|
||||
importlib_metadata
|
||||
mistral_common >= 1.3.4
|
||||
pyyaml
|
||||
|
@ -1,3 +0,0 @@
|
||||
# Mamba dependencies
|
||||
mamba-ssm>=1.2.2
|
||||
causal-conv1d>=1.2.0
|
@ -8,3 +8,4 @@ botocore
|
||||
ray >= 2.10.0
|
||||
peft
|
||||
pytest-asyncio
|
||||
tensorizer>=2.9.0
|
@ -11,12 +11,14 @@ pytest-shard
|
||||
|
||||
# testing utils
|
||||
awscli
|
||||
einops # required for MPT and qwen-vl
|
||||
einops # required for MPT, qwen-vl and Mamba
|
||||
httpx
|
||||
librosa # required for audio test
|
||||
peft
|
||||
requests
|
||||
ray
|
||||
sentence-transformers # required for embedding
|
||||
soundfile # required for audio test
|
||||
compressed-tensors==0.4.0 # required for compressed-tensors
|
||||
timm # required for internvl test
|
||||
transformers_stream_generator # required for qwen-vl test
|
||||
@ -30,4 +32,4 @@ aiohttp
|
||||
|
||||
# quantization
|
||||
bitsandbytes==0.42.0
|
||||
buildkite-test-collector==0.1.8
|
||||
buildkite-test-collector==0.1.8
|
||||
|
@ -4,4 +4,4 @@
|
||||
# Dependencies for TPU
|
||||
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
|
||||
# You can install the dependencies in Dockerfile.tpu.
|
||||
ray
|
||||
ray[default]
|
||||
|
4
setup.py
4
setup.py
@ -362,7 +362,8 @@ def get_vllm_version() -> str:
|
||||
version = find_version(get_path("vllm", "version.py"))
|
||||
|
||||
if _no_device():
|
||||
version += "+empty"
|
||||
if envs.VLLM_TARGET_DEVICE == "empty":
|
||||
version += "+empty"
|
||||
elif _is_cuda():
|
||||
cuda_version = str(get_nvcc_cuda_version())
|
||||
if cuda_version != MAIN_CUDA_VERSION:
|
||||
@ -501,6 +502,7 @@ setup(
|
||||
ext_modules=ext_modules,
|
||||
extras_require={
|
||||
"tensorizer": ["tensorizer>=2.9.0"],
|
||||
"audio": ["librosa", "soundfile"] # Required for audio processing
|
||||
},
|
||||
cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {},
|
||||
package_data=package_data,
|
||||
|
@ -1,5 +1,6 @@
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from ..utils import VLLM_PATH, RemoteOpenAIServer
|
||||
|
||||
@ -31,9 +32,10 @@ def server():
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server):
|
||||
return server.get_async_client()
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -6,6 +6,7 @@ prefill requests are chunked.
|
||||
|
||||
Run `pytest tests/models/test_chunked_prefill.py`.
|
||||
"""
|
||||
from contextlib import nullcontext
|
||||
|
||||
import pytest
|
||||
|
||||
@ -15,18 +16,6 @@ MODELS = [
|
||||
"facebook/opt-125m",
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
]
|
||||
E5M2_KV_MODELS = [
|
||||
"facebook/opt-125m",
|
||||
"meta-llama/Llama-2-7b-chat-hf",
|
||||
]
|
||||
E4M3_KV_MODELS = [
|
||||
"meta-llama/Llama-2-7b-chat-hf", "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
|
||||
"nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
|
||||
]
|
||||
KV_CACHE_QUANTIZATION_PATHS = {
|
||||
"meta-llama/Llama-2-7b-chat-hf":
|
||||
"./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json"
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@ -77,10 +66,10 @@ def test_models(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_cache_dtype,model",
|
||||
[("fp8_e5m2", m)
|
||||
for m in E5M2_KV_MODELS] + [("fp8_e4m3", m)
|
||||
for m in E4M3_KV_MODELS])
|
||||
@pytest.mark.parametrize(
|
||||
"kv_cache_dtype,model",
|
||||
[("fp8_e4m3",
|
||||
"nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme")])
|
||||
# Due to low-precision numerical divergence, we only test logprob of 4 tokens
|
||||
@pytest.mark.parametrize("max_tokens", [4])
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [4, 16])
|
||||
@ -88,6 +77,9 @@ def test_models(
|
||||
# NOTE: Increasing this in this suite will fail CI because we currently cannot
|
||||
# reset distributed env properly. Use a value > 1 just when you test.
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1])
|
||||
# Due to low-precision numerical divergence, this test is too sensitive to
|
||||
# the async postprocessor
|
||||
@pytest.mark.parametrize("disable_async_output_proc", [True])
|
||||
def test_models_with_fp8_kv_cache(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
@ -97,36 +89,25 @@ def test_models_with_fp8_kv_cache(
|
||||
chunked_prefill_token_size: int,
|
||||
enforce_eager: bool,
|
||||
tensor_parallel_size: int,
|
||||
disable_async_output_proc: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Only checks log probs match between chunked-prefill and
|
||||
non-chunked-prefill version of vLLM model runner.
|
||||
|
||||
This test is used when there is discrepancy in kernels
|
||||
/ numerics (e.g. when using lower-precision types like FP8).
|
||||
Check output logprobs match between no_chunked_prefill and chunked_prefill
|
||||
with fp8 kv cache. General fp8 kv-cache tests are covered in test_fp8.py,
|
||||
so here we only check chunked prefill.
|
||||
"""
|
||||
NUM_LOG_PROBS = 8
|
||||
|
||||
if model == "facebook/opt-125m":
|
||||
pytest.skip(
|
||||
"#7378: CUDA illegal memory access (undiagnosed) facebook/opt-125m"
|
||||
)
|
||||
|
||||
max_num_seqs = chunked_prefill_token_size
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
|
||||
extra_kwargs = {}
|
||||
if model in KV_CACHE_QUANTIZATION_PATHS:
|
||||
extra_kwargs["quantization_param_path"] = KV_CACHE_QUANTIZATION_PATHS[
|
||||
model]
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
enforce_eager=enforce_eager,
|
||||
max_num_seqs=max_num_seqs,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
**extra_kwargs,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
) as vllm_model:
|
||||
no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, NUM_LOG_PROBS)
|
||||
@ -139,7 +120,7 @@ def test_models_with_fp8_kv_cache(
|
||||
enforce_eager=enforce_eager,
|
||||
max_num_seqs=max_num_seqs,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
**extra_kwargs,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
) as vllm_model:
|
||||
chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, NUM_LOG_PROBS)
|
||||
@ -150,3 +131,68 @@ def test_models_with_fp8_kv_cache(
|
||||
name_0="no_chunked_prefill",
|
||||
name_1="chunked_prefill",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_tokens", [16])
|
||||
@pytest.mark.parametrize("enforce_eager", [False])
|
||||
@pytest.mark.parametrize("chunk_size", [30, 32])
|
||||
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
|
||||
# NOTE: Increasing this in this suite will fail CI because we currently cannot
|
||||
# reset distributed env properly. Use a value > 1 just when you test.
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1])
|
||||
def test_with_prefix_caching(
|
||||
vllm_runner,
|
||||
max_tokens: int,
|
||||
enforce_eager: bool,
|
||||
chunk_size: int,
|
||||
use_v2_block_manager: bool,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
"""
|
||||
Checks exact match decode with and without prefix caching
|
||||
with chunked prefill enabled.
|
||||
"""
|
||||
model = "meta-llama/Llama-2-7b-chat-hf"
|
||||
# The common prompt has 142 tokens with Llama-2 tokenizer.
|
||||
common_prompt = "You are a helpful AI assistant " * 20
|
||||
unique_prompts = [
|
||||
"Question", # Warmup
|
||||
"Question", # Fully cached
|
||||
"Another question", # Partial cached
|
||||
]
|
||||
full_prompts = [f"{common_prompt}\n{p}" for p in unique_prompts]
|
||||
|
||||
max_num_batched_tokens = max_num_seqs = chunk_size
|
||||
outputs = {} # type: ignore
|
||||
check_result = True
|
||||
for enable in (True, False):
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype="half",
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
enable_chunked_prefill=True,
|
||||
enable_prefix_caching=enable,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
use_v2_block_manager=use_v2_block_manager,
|
||||
enforce_eager=enforce_eager,
|
||||
max_num_seqs=max_num_seqs,
|
||||
) as vllm_model:
|
||||
# It should fail when prefix caching is enable and chunk
|
||||
# size is not a multiple of block size (16).
|
||||
should_fail = chunk_size % 16 != 0 and enable
|
||||
check_result &= not should_fail
|
||||
outputs[enable] = []
|
||||
# Send the request one-by-one to ensure the cache is populated.
|
||||
with pytest.raises(ValueError) if should_fail else nullcontext():
|
||||
for prompt in full_prompts:
|
||||
outputs[enable] += vllm_model.generate_greedy([prompt],
|
||||
max_tokens)
|
||||
|
||||
# Check results only if we did not expect a failure.
|
||||
if check_result:
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=outputs[False],
|
||||
outputs_1_lst=outputs[True],
|
||||
name_0="w/o prefix caching",
|
||||
name_1="with prefix caching",
|
||||
)
|
||||
|
@ -209,7 +209,6 @@ def test_swap_infeasible(
|
||||
prefill_blocks = 2
|
||||
decode_blocks = max_tokens // BLOCK_SIZE
|
||||
example_prompts = example_prompts[:1]
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
|
59
tests/compile/test_wrapper.py
Normal file
59
tests/compile/test_wrapper.py
Normal file
@ -0,0 +1,59 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
|
||||
|
||||
|
||||
class MyMod(torch.nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||
if cache is not None:
|
||||
return x + cache
|
||||
return x * 2
|
||||
|
||||
|
||||
class MyWrapper(TorchCompileWrapperWithCustomDispacther):
|
||||
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
compiled_callable = torch.compile(self.forward, backend="eager")
|
||||
super().__init__(compiled_callable)
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||
# this is the function to be compiled
|
||||
return self.model(x, cache)
|
||||
|
||||
def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||
# let torch.compile compile twice
|
||||
if len(self.compiled_codes) == 2:
|
||||
dispatch_id = 0 if cache is None else 1
|
||||
with self.dispatch_to_code(dispatch_id):
|
||||
return self.forward(x, cache)
|
||||
else:
|
||||
return self.compiled_callable(x, cache)
|
||||
|
||||
|
||||
def test_torch_compile_wrapper():
|
||||
mod = MyMod()
|
||||
wrappers = []
|
||||
for i in range(3):
|
||||
torch._dynamo.reset()
|
||||
wrapper = MyWrapper(mod)
|
||||
wrappers.append(wrapper)
|
||||
x = torch.tensor([1])
|
||||
wrapper(x, None) # profile run, compile
|
||||
# create a cache tensor
|
||||
cache = torch.tensor([2])
|
||||
wrapper(x, cache) # warm up with cache, recompile
|
||||
|
||||
# for new input, dispatch to the compiled code directly
|
||||
new_x = torch.tensor([3])
|
||||
assert wrapper(new_x,
|
||||
None).item() == 6 # dispatch to the first compiled code
|
||||
assert wrapper(
|
||||
new_x, cache).item() == 5 # dispatch to the second compiled code
|
||||
|
||||
for wrapper in wrappers:
|
||||
# make sure they have independent compiled codes
|
||||
assert len(wrapper.compiled_codes) == 2
|
@ -41,6 +41,10 @@ _TEST_DIR = os.path.dirname(__file__)
|
||||
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
|
||||
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
|
||||
|
||||
PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
|
||||
PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
|
||||
List[List[Tuple[np.ndarray, int]]]]
|
||||
|
||||
|
||||
def _read_prompts(filename: str) -> List[str]:
|
||||
with open(filename, "r") as f:
|
||||
@ -161,7 +165,7 @@ def example_encoder_decoder_prompts(
|
||||
decoder prompt) tuple.
|
||||
|
||||
Returns:
|
||||
|
||||
|
||||
* Encoder prompt list
|
||||
* Decoder prompt list (reverse of encoder prompt list)
|
||||
'''
|
||||
@ -205,8 +209,14 @@ class HfRunner:
|
||||
|
||||
def wrap_device(self, input: _T) -> _T:
|
||||
if not is_cpu():
|
||||
# Check if the input is already on the GPU
|
||||
if hasattr(input, 'device') and input.device.type == "cuda":
|
||||
return input # Already on GPU, no need to move
|
||||
return input.to("cuda")
|
||||
else:
|
||||
# Check if the input is already on the CPU
|
||||
if hasattr(input, 'device') and input.device.type == "cpu":
|
||||
return input # Already on CPU, no need to move
|
||||
return input.to("cpu")
|
||||
|
||||
def __init__(
|
||||
@ -578,8 +588,7 @@ class VllmRunner:
|
||||
self,
|
||||
prompts: List[str],
|
||||
sampling_params: SamplingParams,
|
||||
images: Optional[Union[List[Image.Image],
|
||||
List[List[Image.Image]]]] = None,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
) -> List[Tuple[List[List[int]], List[str]]]:
|
||||
if images is not None:
|
||||
assert len(prompts) == len(images)
|
||||
@ -623,10 +632,8 @@ class VllmRunner:
|
||||
self,
|
||||
prompts: List[str],
|
||||
sampling_params: SamplingParams,
|
||||
images: Optional[Union[List[Image.Image],
|
||||
List[List[Image.Image]]]] = None,
|
||||
audios: Optional[Union[List[Tuple[np.ndarray, int]],
|
||||
List[List[Tuple[np.ndarray, int]]]]] = None
|
||||
images: Optional[PromptImageInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
|
||||
assert sampling_params.logprobs is not None
|
||||
|
||||
@ -676,10 +683,8 @@ class VllmRunner:
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
images: Optional[Union[List[Image.Image],
|
||||
List[List[Image.Image]]]] = None,
|
||||
audios: Optional[Union[List[Tuple[np.ndarray, int]],
|
||||
List[List[Tuple[np.ndarray, int]]]]] = None,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
|
||||
greedy_logprobs_params = SamplingParams(temperature=0.0,
|
||||
|
@ -708,6 +708,37 @@ class TestPrefixCachingBlockAllocator:
|
||||
token_ids=token_ids)
|
||||
assert allocator.get_prefix_cache_hit_rate() > 0.99
|
||||
|
||||
# Test case for marking cache hit blocks as computed right after
|
||||
# a batch of prefill sequences are scheduled.
|
||||
@staticmethod
|
||||
def test_touch_block():
|
||||
block_size = 16
|
||||
common_blocks = 4
|
||||
allocator = PrefixCachingBlockAllocator(num_blocks=8,
|
||||
block_size=block_size)
|
||||
|
||||
common_token_ids = list(range(block_size * common_blocks))
|
||||
|
||||
# Mimic the behavior of allocating the same block chain
|
||||
# (i.e., common prefix) for a batch of 3 different prefill sequences.
|
||||
for _ in range(3):
|
||||
blocks = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=common_token_ids,
|
||||
allocator=allocator,
|
||||
)
|
||||
block_ids = [block.block_id for block in blocks]
|
||||
# The allocated blocks should be marked as touched
|
||||
# but not computed.
|
||||
computed_block_ids = allocator.get_computed_block_ids(
|
||||
[], block_ids, skip_last_block_id=False)
|
||||
assert len(computed_block_ids) == 0
|
||||
|
||||
allocator.mark_blocks_as_computed([])
|
||||
computed_block_ids = allocator.get_computed_block_ids(
|
||||
[], block_ids, skip_last_block_id=False)
|
||||
assert len(computed_block_ids) == common_blocks
|
||||
|
||||
@staticmethod
|
||||
def create_immutable_chain(
|
||||
block_size: int,
|
||||
|
@ -595,3 +595,43 @@ def test_sliding_window_multi_seq():
|
||||
|
||||
# assert all blocks are free now
|
||||
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
|
||||
|
||||
|
||||
def test_mark_blocks_as_computed_with_prefix_cache_and_chunked_prefill():
|
||||
"""When prefix cache and chunked prefill are enabled, the block manager
|
||||
should only mark a chunk of blocks as computed instead of all blocks.
|
||||
"""
|
||||
|
||||
block_size = 4
|
||||
num_cpu_blocks = 0
|
||||
num_gpu_blocks = 16
|
||||
block_manager = BlockSpaceManagerV1(block_size,
|
||||
num_gpu_blocks,
|
||||
num_cpu_blocks,
|
||||
watermark=0,
|
||||
enable_caching=True)
|
||||
|
||||
# Set prompt size to have num_gpu_blocks - 1 full blocks.
|
||||
prompt_length = block_size * num_gpu_blocks - 1
|
||||
|
||||
# Allocate (reserve) all blocks.
|
||||
_, seq_group = create_dummy_prompt("0",
|
||||
prompt_length,
|
||||
block_size=block_size)
|
||||
block_manager.allocate(seq_group)
|
||||
assert seq_group.seqs[0].n_blocks == num_gpu_blocks
|
||||
|
||||
# 1st chunk: Compute 2 and half blocks. Should mark 2 blocks as computed.
|
||||
token_chunk_size = int(block_size * 2.5)
|
||||
block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
|
||||
computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
|
||||
assert len(computed_blocks) == 2
|
||||
|
||||
# Actual computed tokens.
|
||||
seq_group.seqs[0].data.update_num_computed_tokens(token_chunk_size)
|
||||
|
||||
# 2nd chunk: Complete 3rd block and additional 4 blocks.
|
||||
token_chunk_size = int(block_size * 4.5)
|
||||
block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
|
||||
computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
|
||||
assert len(computed_blocks) == 7
|
||||
|
@ -21,7 +21,7 @@ def append_new_token(seq_group, token_id: int):
|
||||
|
||||
|
||||
def schedule_and_update_computed_tokens(scheduler):
|
||||
metas, out = scheduler.schedule()
|
||||
metas, out, _ = scheduler.schedule()
|
||||
for s, meta in zip(out.scheduled_seq_groups, metas):
|
||||
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
|
||||
return metas, out
|
||||
@ -180,7 +180,7 @@ def test_maximal_decoding():
|
||||
"""Verify decoding requests are prioritized."""
|
||||
block_size = 4
|
||||
max_seqs = 2
|
||||
max_model_len = 2
|
||||
max_model_len = 8
|
||||
max_num_batched_tokens = 2
|
||||
scheduler_config = SchedulerConfig(max_num_batched_tokens,
|
||||
max_seqs,
|
||||
@ -562,3 +562,42 @@ def test_chunked_prefill_max_seqs():
|
||||
assert len(get_sequence_groups(out)) == max_seqs
|
||||
assert not running[0].is_prefill()
|
||||
assert not running[1].is_prefill()
|
||||
|
||||
|
||||
def test_perfix_caching():
|
||||
"""Verify allocating full blocks when prefix caching is enabled."""
|
||||
block_size = 4
|
||||
max_seqs = 10
|
||||
max_model_len = 80
|
||||
max_num_batched_tokens = 64
|
||||
scheduler_config = SchedulerConfig(max_num_batched_tokens,
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True)
|
||||
cache_config = CacheConfig(block_size,
|
||||
1.0,
|
||||
1,
|
||||
"auto",
|
||||
enable_prefix_caching=True)
|
||||
cache_config.num_cpu_blocks = 0
|
||||
cache_config.num_gpu_blocks = 32
|
||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||
running: List[SequenceGroup] = []
|
||||
|
||||
# Add seq groups to scheduler.
|
||||
for i in range(2):
|
||||
_, seq_group = create_dummy_prompt(str(i),
|
||||
block_size=block_size,
|
||||
prompt_length=50)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
running.append(seq_group)
|
||||
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert set(get_sequence_groups(out)) == set(running)
|
||||
assert seq_group_meta[0].token_chunk_size == 50
|
||||
# Verify it is chunked. Note that although the budget is 64-50=14,
|
||||
# we only allocate full blocks for prefix caching, so only 4*(14//4)=12
|
||||
# tokens are allocated.
|
||||
assert seq_group_meta[1].token_chunk_size == 12
|
||||
assert out.num_prefill_groups == 2
|
||||
assert out.num_batched_tokens == 62
|
||||
|
@ -199,7 +199,7 @@ def append_new_token(out, token_id: int):
|
||||
|
||||
|
||||
def schedule_and_update_computed_tokens(scheduler):
|
||||
metas, out = scheduler.schedule()
|
||||
metas, out, _ = scheduler.schedule()
|
||||
for s, meta in zip(out.scheduled_seq_groups, metas):
|
||||
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
|
||||
return metas, out
|
||||
|
2
tests/data/test_config.yaml
Normal file
2
tests/data/test_config.yaml
Normal file
@ -0,0 +1,2 @@
|
||||
port: 12312
|
||||
tensor_parallel_size: 2
|
@ -7,6 +7,8 @@ from vllm import CompletionOutput, LLMEngine, SamplingParams
|
||||
MODEL = "meta-llama/llama-2-7b-hf"
|
||||
MAX_TOKENS = 200
|
||||
|
||||
IS_ASYNC = False
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def vllm_model(vllm_runner):
|
||||
@ -14,77 +16,13 @@ def vllm_model(vllm_runner):
|
||||
yield vllm_model
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_basic(vllm_model):
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["."],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer organization",
|
||||
expected_reason=".")
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["."],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organization.",
|
||||
expected_reason=".")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_multi_tokens(vllm_model):
|
||||
_test_stopping(
|
||||
vllm_model.model.llm_engine,
|
||||
stop=["group of peo", "short"],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer organization. We are a ",
|
||||
expected_reason="group of peo")
|
||||
|
||||
_test_stopping(
|
||||
vllm_model.model.llm_engine,
|
||||
stop=["group of peo", "short"],
|
||||
include_in_output=True,
|
||||
expected_output=
|
||||
"VLLM is a 100% volunteer organization. We are a group of peo",
|
||||
expected_reason="group of peo")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_partial_token(vllm_model):
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["gani"],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer or",
|
||||
expected_reason="gani")
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["gani"],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organi",
|
||||
expected_reason="gani")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_token_id(vllm_model):
|
||||
# token id 13013 => " organization"
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop_token_ids=[13013],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer",
|
||||
expected_reason=13013)
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop_token_ids=[13013],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organization",
|
||||
expected_reason=13013)
|
||||
|
||||
|
||||
def _test_stopping(llm_engine: LLMEngine,
|
||||
expected_output: str,
|
||||
expected_reason: Any,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
include_in_output: bool = False) -> None:
|
||||
include_in_output: bool = False,
|
||||
use_async_output_proc: bool = False) -> None:
|
||||
llm_engine.add_request(
|
||||
"id", "A story about vLLM:\n",
|
||||
SamplingParams(
|
||||
@ -98,6 +36,10 @@ def _test_stopping(llm_engine: LLMEngine,
|
||||
output: Optional[CompletionOutput] = None
|
||||
output_text = ""
|
||||
stop_reason = None
|
||||
|
||||
if use_async_output_proc:
|
||||
llm_engine.step()
|
||||
|
||||
while llm_engine.has_unfinished_requests():
|
||||
(request_output, ) = llm_engine.step()
|
||||
(output, ) = request_output.outputs
|
||||
@ -110,3 +52,112 @@ def _test_stopping(llm_engine: LLMEngine,
|
||||
assert output is not None
|
||||
assert output_text == expected_output
|
||||
assert stop_reason == expected_reason
|
||||
|
||||
|
||||
def _set_async_mode(llm_engine, is_async):
|
||||
llm_engine.scheduler[0].use_async_output_proc = is_async
|
||||
|
||||
|
||||
def _stop_basic(llm_engine, is_async):
|
||||
_test_stopping(llm_engine,
|
||||
stop=["."],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer organization",
|
||||
expected_reason=".",
|
||||
use_async_output_proc=is_async)
|
||||
|
||||
_test_stopping(llm_engine,
|
||||
stop=["."],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organization.",
|
||||
expected_reason=".",
|
||||
use_async_output_proc=is_async)
|
||||
|
||||
|
||||
def _stop_multi_tokens(llm_engine, is_async):
|
||||
_test_stopping(
|
||||
llm_engine,
|
||||
stop=["group of peo", "short"],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer organization. We are a ",
|
||||
expected_reason="group of peo",
|
||||
use_async_output_proc=is_async)
|
||||
|
||||
_test_stopping(
|
||||
llm_engine,
|
||||
stop=["group of peo", "short"],
|
||||
include_in_output=True,
|
||||
expected_output=
|
||||
"VLLM is a 100% volunteer organization. We are a group of peo",
|
||||
expected_reason="group of peo",
|
||||
use_async_output_proc=is_async)
|
||||
|
||||
|
||||
def _stop_partial_token(llm_engine, is_async):
|
||||
_test_stopping(llm_engine,
|
||||
stop=["gani"],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer or",
|
||||
expected_reason="gani",
|
||||
use_async_output_proc=is_async)
|
||||
|
||||
_test_stopping(llm_engine,
|
||||
stop=["gani"],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organi",
|
||||
expected_reason="gani",
|
||||
use_async_output_proc=is_async)
|
||||
|
||||
|
||||
def _stop_token_id(llm_engine, is_async):
|
||||
# token id 13013 => " organization"
|
||||
|
||||
_test_stopping(llm_engine,
|
||||
stop_token_ids=[13013],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer",
|
||||
expected_reason=13013,
|
||||
use_async_output_proc=is_async)
|
||||
|
||||
_test_stopping(llm_engine,
|
||||
stop_token_ids=[13013],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organization",
|
||||
expected_reason=13013,
|
||||
use_async_output_proc=is_async)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_basic(vllm_model):
|
||||
_set_async_mode(vllm_model.model.llm_engine, True)
|
||||
_stop_basic(vllm_model.model.llm_engine, is_async=True)
|
||||
|
||||
_set_async_mode(vllm_model.model.llm_engine, False)
|
||||
_stop_basic(vllm_model.model.llm_engine, is_async=False)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_multi_tokens(vllm_model):
|
||||
_set_async_mode(vllm_model.model.llm_engine, True)
|
||||
_stop_multi_tokens(vllm_model.model.llm_engine, is_async=True)
|
||||
|
||||
_set_async_mode(vllm_model.model.llm_engine, False)
|
||||
_stop_multi_tokens(vllm_model.model.llm_engine, is_async=False)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_partial_token(vllm_model):
|
||||
_set_async_mode(vllm_model.model.llm_engine, True)
|
||||
_stop_partial_token(vllm_model.model.llm_engine, is_async=True)
|
||||
|
||||
_set_async_mode(vllm_model.model.llm_engine, False)
|
||||
_stop_partial_token(vllm_model.model.llm_engine, is_async=False)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_token_id(vllm_model):
|
||||
_set_async_mode(vllm_model.model.llm_engine, True)
|
||||
_stop_token_id(vllm_model.model.llm_engine, is_async=True)
|
||||
|
||||
_set_async_mode(vllm_model.model.llm_engine, False)
|
||||
_stop_token_id(vllm_model.model.llm_engine, is_async=False)
|
||||
|
@ -6,6 +6,7 @@ import pytest
|
||||
from vllm import LLM, RequestOutput, SamplingParams
|
||||
|
||||
from ...conftest import cleanup
|
||||
from ..openai.test_vision import TEST_IMAGE_URLS
|
||||
|
||||
MODEL_NAME = "facebook/opt-125m"
|
||||
|
||||
@ -159,3 +160,36 @@ def test_chat():
|
||||
]
|
||||
outputs = llm.chat(messages)
|
||||
assert len(outputs) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image_urls",
|
||||
[[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
|
||||
def test_chat_multi_image(image_urls: List[str]):
|
||||
llm = LLM(
|
||||
model="microsoft/Phi-3.5-vision-instruct",
|
||||
dtype="bfloat16",
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
enforce_eager=True,
|
||||
trust_remote_code=True,
|
||||
limit_mm_per_prompt={"image": 2},
|
||||
)
|
||||
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
*({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
} for image_url in image_urls),
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this image?"
|
||||
},
|
||||
],
|
||||
}]
|
||||
outputs = llm.chat(messages)
|
||||
assert len(outputs) >= 0
|
||||
|
48
tests/entrypoints/llm/test_lazy_outlines.py
Normal file
48
tests/entrypoints/llm/test_lazy_outlines.py
Normal file
@ -0,0 +1,48 @@
|
||||
import sys
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
def test_lazy_outlines(sample_regex):
|
||||
"""If users don't use guided decoding, outlines should not be imported.
|
||||
"""
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
llm = LLM(model="facebook/opt-125m",
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.3)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
# make sure outlines is not imported
|
||||
assert 'outlines' not in sys.modules
|
||||
|
||||
llm = LLM(model="facebook/opt-125m",
|
||||
enforce_eager=True,
|
||||
guided_decoding_backend="lm-format-enforcer",
|
||||
gpu_memory_utilization=0.3)
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
outputs = llm.generate(
|
||||
prompts=[
|
||||
f"Give an example IPv4 address with this regex: {sample_regex}"
|
||||
] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
guided_options_request=dict(guided_regex=sample_regex))
|
||||
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
# make sure outlines is not imported
|
||||
assert 'outlines' not in sys.modules
|
@ -2,6 +2,7 @@ from typing import Dict, List
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.multimodal.utils import encode_audio_base64, fetch_audio
|
||||
@ -28,9 +29,10 @@ def server():
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server):
|
||||
return server.get_async_client()
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -2,6 +2,7 @@ from http import HTTPStatus
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import requests
|
||||
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
@ -28,9 +29,10 @@ def server():
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server):
|
||||
return server.get_async_client()
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -6,6 +6,7 @@ from typing import Dict, List, Optional
|
||||
import jsonschema
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import torch
|
||||
from openai import BadRequestError
|
||||
|
||||
@ -46,9 +47,10 @@ def server(zephyr_lora_files, zephyr_lora_added_tokens_files): # noqa: F811
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server):
|
||||
return server.get_async_client()
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -837,6 +839,39 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
|
||||
assert loaded == {"result": 2}, loaded
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_format_json_schema(client: openai.AsyncOpenAI):
|
||||
for _ in range(2):
|
||||
resp = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=[{
|
||||
"role":
|
||||
"user",
|
||||
"content": ('what is 1+1? please respond with a JSON object, '
|
||||
'the format is {"result": 2}')
|
||||
}],
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "foo_test",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"result": {
|
||||
"type": "integer"
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
content = resp.choices[0].message.content
|
||||
assert content is not None
|
||||
|
||||
loaded = json.loads(content)
|
||||
assert loaded == {"result": 2}, loaded
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extra_fields(client: openai.AsyncOpenAI):
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
|
@ -8,6 +8,7 @@ from typing import Dict, List, Optional
|
||||
import jsonschema
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
# downloading lora to test lora requests
|
||||
from huggingface_hub import snapshot_download
|
||||
from openai import BadRequestError
|
||||
@ -89,11 +90,17 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
|
||||
|
||||
@pytest.fixture(scope="module",
|
||||
params=["", "--disable-frontend-multiprocessing"])
|
||||
def client(default_server_args, request):
|
||||
def server(default_server_args, request):
|
||||
if request.param:
|
||||
default_server_args.append(request.param)
|
||||
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
|
||||
yield remote_server.get_async_client()
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -3,6 +3,7 @@ import base64
|
||||
import numpy as np
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
@ -24,10 +25,10 @@ def embedding_server():
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.fixture(scope="module")
|
||||
def embedding_client(embedding_server):
|
||||
return embedding_server.get_async_client()
|
||||
@pytest_asyncio.fixture
|
||||
async def embedding_client(embedding_server):
|
||||
async with embedding_server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -128,9 +129,18 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
|
||||
for data in responses_base64.data:
|
||||
decoded_responses_base64_data.append(
|
||||
np.frombuffer(base64.b64decode(data.embedding),
|
||||
dtype="float").tolist())
|
||||
dtype="float32").tolist())
|
||||
|
||||
assert responses_float.data[0].embedding == decoded_responses_base64_data[
|
||||
0]
|
||||
assert responses_float.data[1].embedding == decoded_responses_base64_data[
|
||||
1]
|
||||
|
||||
# Default response is float32 decoded from base64 by OpenAI Client
|
||||
responses_default = await embedding_client.embeddings.create(
|
||||
input=input_texts, model=model_name)
|
||||
|
||||
assert responses_float.data[0].embedding == responses_default.data[
|
||||
0].embedding
|
||||
assert responses_float.data[1].embedding == responses_default.data[
|
||||
1].embedding
|
||||
|
@ -1,5 +1,6 @@
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
@ -18,9 +19,10 @@ def server():
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server):
|
||||
return server.get_async_client()
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -1,7 +1,12 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import requests
|
||||
from prometheus_client.parser import text_string_to_metric_families
|
||||
from transformers import AutoTokenizer
|
||||
@ -31,11 +36,17 @@ def default_server_args():
|
||||
"--enable-chunked-prefill",
|
||||
"--disable-frontend-multiprocessing",
|
||||
])
|
||||
def client(default_server_args, request):
|
||||
def server(default_server_args, request):
|
||||
if request.param:
|
||||
default_server_args.append(request.param)
|
||||
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
|
||||
yield remote_server.get_async_client()
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as cl:
|
||||
yield cl
|
||||
|
||||
|
||||
_PROMPT = "Hello my name is Robert and I love magic"
|
||||
@ -177,3 +188,48 @@ async def test_metrics_exist(client: openai.AsyncOpenAI):
|
||||
|
||||
for metric in EXPECTED_METRICS:
|
||||
assert metric in response.text
|
||||
|
||||
|
||||
def test_metrics_exist_run_batch():
|
||||
input_batch = """{"custom_id": "request-0", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}}""" # noqa: E501
|
||||
|
||||
base_url = "0.0.0.0"
|
||||
port = "8001"
|
||||
server_url = f"http://{base_url}:{port}"
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
"w") as input_file, tempfile.NamedTemporaryFile(
|
||||
"r") as output_file:
|
||||
input_file.write(input_batch)
|
||||
input_file.flush()
|
||||
proc = subprocess.Popen([
|
||||
sys.executable,
|
||||
"-m",
|
||||
"vllm.entrypoints.openai.run_batch",
|
||||
"-i",
|
||||
input_file.name,
|
||||
"-o",
|
||||
output_file.name,
|
||||
"--model",
|
||||
"intfloat/e5-mistral-7b-instruct",
|
||||
"--enable-metrics",
|
||||
"--url",
|
||||
base_url,
|
||||
"--port",
|
||||
port,
|
||||
], )
|
||||
|
||||
def is_server_up(url):
|
||||
try:
|
||||
response = requests.get(url)
|
||||
return response.status_code == 200
|
||||
except requests.ConnectionError:
|
||||
return False
|
||||
|
||||
while not is_server_up(server_url):
|
||||
time.sleep(1)
|
||||
|
||||
response = requests.get(server_url + "/metrics")
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
|
||||
proc.wait()
|
||||
|
@ -1,5 +1,6 @@
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
# downloading lora to test lora requests
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
@ -43,9 +44,10 @@ def server(zephyr_lora_files):
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server):
|
||||
return server.get_async_client()
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -25,59 +25,63 @@ def server_with_return_tokens_as_token_ids_flag(
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_return_tokens_as_token_ids_completion(
|
||||
server_with_return_tokens_as_token_ids_flag):
|
||||
client = server_with_return_tokens_as_token_ids_flag.get_async_client()
|
||||
async with server_with_return_tokens_as_token_ids_flag.get_async_client(
|
||||
) as client:
|
||||
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
# Include Unicode characters to test for dividing a single
|
||||
# character across multiple tokens: 🎉 is [28705, 31862] for the
|
||||
# Zephyr tokenizer
|
||||
prompt="Say 'Hello, world! 🎉'",
|
||||
echo=True,
|
||||
temperature=0,
|
||||
max_tokens=10,
|
||||
logprobs=1)
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
# Include Unicode characters to test for dividing a single
|
||||
# character across multiple tokens: 🎉 is [28705, 31862] for the
|
||||
# Zephyr tokenizer
|
||||
prompt="Say 'Hello, world! 🎉'",
|
||||
echo=True,
|
||||
temperature=0,
|
||||
max_tokens=10,
|
||||
logprobs=1)
|
||||
|
||||
text = completion.choices[0].text
|
||||
token_strs = completion.choices[0].logprobs.tokens
|
||||
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
|
||||
# Check that the token representations are consistent between raw tokens
|
||||
# and top_logprobs
|
||||
# Slice off the first one, because there's no scoring associated with BOS
|
||||
top_logprobs = completion.choices[0].logprobs.top_logprobs[1:]
|
||||
top_logprob_keys = [
|
||||
next(iter(logprob_by_tokens)) for logprob_by_tokens in top_logprobs
|
||||
]
|
||||
assert token_strs[1:] == top_logprob_keys
|
||||
text = completion.choices[0].text
|
||||
token_strs = completion.choices[0].logprobs.tokens
|
||||
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
|
||||
# Check that the token representations are consistent between raw
|
||||
# tokens and top_logprobs
|
||||
# Slice off the first one, because there's no scoring associated
|
||||
# with BOS
|
||||
top_logprobs = completion.choices[0].logprobs.top_logprobs[1:]
|
||||
top_logprob_keys = [
|
||||
next(iter(logprob_by_tokens)) for logprob_by_tokens in top_logprobs
|
||||
]
|
||||
assert token_strs[1:] == top_logprob_keys
|
||||
|
||||
# Check that decoding the tokens gives the expected text
|
||||
tokens = [int(token.removeprefix("token_id:")) for token in token_strs]
|
||||
assert text == tokenizer.decode(tokens, skip_special_tokens=True)
|
||||
# Check that decoding the tokens gives the expected text
|
||||
tokens = [int(token.removeprefix("token_id:")) for token in token_strs]
|
||||
assert text == tokenizer.decode(tokens, skip_special_tokens=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_return_tokens_as_token_ids_completion(
|
||||
server_with_return_tokens_as_token_ids_flag):
|
||||
client = server_with_return_tokens_as_token_ids_flag.get_async_client()
|
||||
response = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
# Include Unicode characters to test for dividing a single
|
||||
# character across multiple tokens: 🎉 is [28705, 31862] for the
|
||||
# Zephyr tokenizer
|
||||
messages=[{
|
||||
"role": "system",
|
||||
"content": "You like to respond in only emojis, like 🎉"
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Please write some emojis: 🐱🐶🎉"
|
||||
}],
|
||||
temperature=0,
|
||||
max_tokens=8,
|
||||
logprobs=True)
|
||||
async with server_with_return_tokens_as_token_ids_flag.get_async_client(
|
||||
) as client:
|
||||
response = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
# Include Unicode characters to test for dividing a single
|
||||
# character across multiple tokens: 🎉 is [28705, 31862] for the
|
||||
# Zephyr tokenizer
|
||||
messages=[{
|
||||
"role": "system",
|
||||
"content": "You like to respond in only emojis, like 🎉"
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Please write some emojis: 🐱🐶🎉"
|
||||
}],
|
||||
temperature=0,
|
||||
max_tokens=8,
|
||||
logprobs=True)
|
||||
|
||||
text = response.choices[0].message.content
|
||||
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
|
||||
token_ids = []
|
||||
for logprob_content in response.choices[0].logprobs.content:
|
||||
token_ids.append(int(logprob_content.token.removeprefix("token_id:")))
|
||||
assert tokenizer.decode(token_ids, skip_special_tokens=True) == text
|
||||
text = response.choices[0].message.content
|
||||
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
|
||||
token_ids = []
|
||||
for logprob_content in response.choices[0].logprobs.content:
|
||||
token_ids.append(
|
||||
int(logprob_content.token.removeprefix("token_id:")))
|
||||
assert tokenizer.decode(token_ids, skip_special_tokens=True) == text
|
||||
|
@ -3,6 +3,7 @@ from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from vllm.config import MultiModalConfig
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
@ -20,6 +21,7 @@ class MockModelConfig:
|
||||
max_model_len = 100
|
||||
tokenizer_revision = None
|
||||
embedding_mode = False
|
||||
multimodal_config = MultiModalConfig()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -35,13 +35,14 @@ async def test_shutdown_on_engine_failure(tmp_path):
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
client = remote_server.get_async_client()
|
||||
async with remote_server.get_async_client() as client:
|
||||
|
||||
with pytest.raises(openai.APIConnectionError):
|
||||
# This crashes the engine
|
||||
await client.completions.create(model="bad-adapter",
|
||||
prompt="Hello, my name is")
|
||||
with pytest.raises(
|
||||
(openai.APIConnectionError, openai.InternalServerError)):
|
||||
# This crashes the engine
|
||||
await client.completions.create(model="bad-adapter",
|
||||
prompt="Hello, my name is")
|
||||
|
||||
# Now the server should shut down
|
||||
return_code = remote_server.proc.wait(timeout=1)
|
||||
assert return_code is not None
|
||||
# Now the server should shut down
|
||||
return_code = remote_server.proc.wait(timeout=3)
|
||||
assert return_code is not None
|
||||
|
@ -1,5 +1,6 @@
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import requests
|
||||
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
@ -42,9 +43,10 @@ def tokenizer_name(model_name: str,
|
||||
model_name == "zephyr-lora2") else model_name
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server):
|
||||
return server.get_async_client()
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -2,14 +2,14 @@ from typing import Dict, List
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from vllm.multimodal.utils import encode_image_base64, fetch_image
|
||||
|
||||
from ...utils import VLLM_PATH, RemoteOpenAIServer
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
|
||||
LLAVA_CHAT_TEMPLATE = VLLM_PATH / "examples/template_llava.jinja"
|
||||
assert LLAVA_CHAT_TEMPLATE.exists()
|
||||
MODEL_NAME = "microsoft/Phi-3.5-vision-instruct"
|
||||
MAXIMUM_IMAGES = 2
|
||||
|
||||
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
|
||||
TEST_IMAGE_URLS = [
|
||||
@ -23,22 +23,19 @@ TEST_IMAGE_URLS = [
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"4096",
|
||||
"--enforce-eager",
|
||||
"--chat-template",
|
||||
str(LLAVA_CHAT_TEMPLATE),
|
||||
"--dtype", "bfloat16", "--max-model-len", "4096", "--max-num-seqs",
|
||||
"5", "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt",
|
||||
f"image={MAXIMUM_IMAGES}"
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server):
|
||||
return server.get_async_client()
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@ -82,7 +79,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=10, prompt_tokens=596, total_tokens=606)
|
||||
completion_tokens=10, prompt_tokens=772, total_tokens=782)
|
||||
|
||||
message = choice.message
|
||||
message = chat_completion.choices[0].message
|
||||
@ -137,7 +134,7 @@ async def test_single_chat_session_image_base64encoded(
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=10, prompt_tokens=596, total_tokens=606)
|
||||
completion_tokens=10, prompt_tokens=772, total_tokens=782)
|
||||
|
||||
message = choice.message
|
||||
message = chat_completion.choices[0].message
|
||||
@ -215,26 +212,22 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI,
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
||||
@pytest.mark.parametrize(
|
||||
"image_urls",
|
||||
[TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))])
|
||||
async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
|
||||
image_url: str):
|
||||
image_urls: List[str]):
|
||||
|
||||
messages = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": [
|
||||
{
|
||||
*({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
},
|
||||
} for image_url in image_urls),
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's in this image?"
|
||||
@ -242,20 +235,30 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
|
||||
],
|
||||
}]
|
||||
|
||||
with pytest.raises(openai.BadRequestError): # test multi-image input
|
||||
await client.chat.completions.create(
|
||||
if len(image_urls) > MAXIMUM_IMAGES:
|
||||
with pytest.raises(openai.BadRequestError): # test multi-image input
|
||||
await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
# the server should still work afterwards
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
completion = completion.choices[0].text
|
||||
assert completion is not None and len(completion) >= 0
|
||||
else:
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
# the server should still work afterwards
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
completion = completion.choices[0].text
|
||||
assert completion is not None and len(completion) >= 0
|
||||
message = chat_completion.choices[0].message
|
||||
assert message.content is not None and len(message.content) >= 0
|
||||
|
389
tests/entrypoints/test_chat_utils.py
Normal file
389
tests/entrypoints/test_chat_utils.py
Normal file
@ -0,0 +1,389 @@
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import (parse_chat_messages,
|
||||
parse_chat_messages_futures)
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.multimodal.utils import encode_image_base64
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
|
||||
PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def phi3v_model_config():
|
||||
return ModelConfig(PHI3V_MODEL_ID,
|
||||
PHI3V_MODEL_ID,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=True,
|
||||
dtype="bfloat16",
|
||||
seed=0,
|
||||
limit_mm_per_prompt={
|
||||
"image": 2,
|
||||
})
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def phi3v_tokenizer():
|
||||
return TokenizerGroup(
|
||||
tokenizer_id=PHI3V_MODEL_ID,
|
||||
enable_lora=False,
|
||||
max_num_seqs=5,
|
||||
max_input_length=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def image_url():
|
||||
image = ImageAsset('cherry_blossom')
|
||||
base64 = encode_image_base64(image.pil_image)
|
||||
return f"data:image/jpeg;base64,{base64}"
|
||||
|
||||
|
||||
def _assert_mm_data_is_image_input(
|
||||
mm_data: Optional[MultiModalDataDict],
|
||||
image_count: int,
|
||||
) -> None:
|
||||
assert mm_data is not None
|
||||
assert set(mm_data.keys()) == {"image"}
|
||||
|
||||
image_data = mm_data.get("image")
|
||||
assert image_data is not None
|
||||
|
||||
if image_count == 1:
|
||||
assert isinstance(image_data, Image.Image)
|
||||
else:
|
||||
assert isinstance(image_data, list) and len(image_data) == image_count
|
||||
|
||||
|
||||
def test_parse_chat_messages_single_image(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_data = parse_chat_messages([{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "What's in the image?"
|
||||
}]
|
||||
}], phi3v_model_config, phi3v_tokenizer)
|
||||
|
||||
assert conversation == [{
|
||||
"role": "user",
|
||||
"content": "<|image_1|>\nWhat's in the image?"
|
||||
}]
|
||||
_assert_mm_data_is_image_input(mm_data, 1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_single_image_async(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_future = parse_chat_messages_futures([{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "What's in the image?"
|
||||
}]
|
||||
}], phi3v_model_config, phi3v_tokenizer)
|
||||
|
||||
assert conversation == [{
|
||||
"role": "user",
|
||||
"content": "<|image_1|>\nWhat's in the image?"
|
||||
}]
|
||||
_assert_mm_data_is_image_input(await mm_future, 1)
|
||||
|
||||
|
||||
def test_parse_chat_messages_multiple_images(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_data = parse_chat_messages([{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "What's in these images?"
|
||||
}]
|
||||
}], phi3v_model_config, phi3v_tokenizer)
|
||||
|
||||
assert conversation == [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"<|image_1|>\n<|image_2|>\nWhat's in these images?"
|
||||
}]
|
||||
_assert_mm_data_is_image_input(mm_data, 2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_multiple_images_async(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_future = parse_chat_messages_futures([{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "What's in these images?"
|
||||
}]
|
||||
}], phi3v_model_config, phi3v_tokenizer)
|
||||
|
||||
assert conversation == [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"<|image_1|>\n<|image_2|>\nWhat's in these images?"
|
||||
}]
|
||||
_assert_mm_data_is_image_input(await mm_future, 2)
|
||||
|
||||
|
||||
def test_parse_chat_messages_placeholder_already_in_prompt(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_data = parse_chat_messages([{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type":
|
||||
"text",
|
||||
"text":
|
||||
"What's in <|image_1|> and how does it compare to <|image_2|>?"
|
||||
}]
|
||||
}], phi3v_model_config, phi3v_tokenizer)
|
||||
|
||||
assert conversation == [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"What's in <|image_1|> and how does it compare to <|image_2|>?"
|
||||
}]
|
||||
_assert_mm_data_is_image_input(mm_data, 2)
|
||||
|
||||
|
||||
def test_parse_chat_messages_placeholder_one_already_in_prompt(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_data = parse_chat_messages([{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type":
|
||||
"text",
|
||||
"text":
|
||||
"What's in <|image_1|> and how does it compare to the other one?"
|
||||
}]
|
||||
}], phi3v_model_config, phi3v_tokenizer)
|
||||
|
||||
assert conversation == [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"<|image_2|>\nWhat's in <|image_1|> and how does it compare to the "
|
||||
"other one?"
|
||||
}]
|
||||
_assert_mm_data_is_image_input(mm_data, 2)
|
||||
|
||||
|
||||
def test_parse_chat_messages_multiple_images_across_messages(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
conversation, mm_data = parse_chat_messages([{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "What's in this image?"
|
||||
}]
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "Some stuff."
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "What about this one?"
|
||||
}]
|
||||
}], phi3v_model_config, phi3v_tokenizer)
|
||||
|
||||
assert conversation == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "<|image_1|>\nWhat's in this image?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Some stuff."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "<|image_2|>\nWhat about this one?"
|
||||
},
|
||||
]
|
||||
_assert_mm_data_is_image_input(mm_data, 2)
|
||||
|
||||
|
||||
def test_parse_chat_messages_rejects_too_many_images_in_one_message(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message="coroutine 'async_get_and_parse_image' was never awaited")
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="At most 2 image\\(s\\) may be provided in one request\\."
|
||||
):
|
||||
parse_chat_messages([{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "What's in these images?"
|
||||
}]
|
||||
}], phi3v_model_config, phi3v_tokenizer)
|
||||
|
||||
|
||||
def test_parse_chat_messages_rejects_too_many_images_across_messages(
|
||||
phi3v_model_config,
|
||||
phi3v_tokenizer,
|
||||
image_url,
|
||||
):
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message="coroutine 'async_get_and_parse_image' was never awaited")
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="At most 2 image\\(s\\) may be provided in one request\\."
|
||||
):
|
||||
parse_chat_messages([{
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "What's in this image?"
|
||||
}]
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "Some stuff."
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "What about these two?"
|
||||
}]
|
||||
}], phi3v_model_config, phi3v_tokenizer)
|
169
tests/kernels/test_awq_triton.py
Normal file
169
tests/kernels/test_awq_triton.py
Normal file
@ -0,0 +1,169 @@
|
||||
"""Tests for the AWQ Triton kernel.
|
||||
|
||||
Run `pytest tests/kernels/test_awq_triton.py`.
|
||||
"""
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.awq_triton import (
|
||||
AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)
|
||||
|
||||
device = "cuda"
|
||||
|
||||
|
||||
def reverse_awq_order(t: torch.Tensor):
|
||||
bits = 4
|
||||
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
reverse_order_tensor = torch.arange(
|
||||
t.shape[-1],
|
||||
dtype=torch.int32,
|
||||
device=t.device,
|
||||
)
|
||||
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
|
||||
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
|
||||
reverse_order_tensor = reverse_order_tensor.view(-1)
|
||||
|
||||
t = t[:, reverse_order_tensor] & 0xF
|
||||
return t
|
||||
|
||||
|
||||
# qweights - [R , C // 8], int32
|
||||
# scales - [R // G, C ], float16
|
||||
# zeros - [R // G, C // 8], int32
|
||||
def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
|
||||
qzeros: torch.Tensor,
|
||||
group_size: int) -> torch.Tensor:
|
||||
|
||||
if group_size == -1:
|
||||
group_size = qweight.shape[0]
|
||||
|
||||
bits = 4
|
||||
shifts = torch.arange(0, 32, bits, device=qzeros.device)
|
||||
|
||||
iweights = torch.bitwise_right_shift(qweight[:, :, None],
|
||||
shifts[None, None, :]).to(torch.int8)
|
||||
|
||||
iweights = iweights.view(iweights.shape[0], -1)
|
||||
|
||||
zeros = torch.bitwise_right_shift(qzeros[:, :, None],
|
||||
shifts[None, None, :]).to(torch.int8)
|
||||
zeros = zeros.view(qzeros.shape[0], -1)
|
||||
zeros = reverse_awq_order(zeros)
|
||||
|
||||
iweights = reverse_awq_order(iweights)
|
||||
|
||||
iweights = torch.bitwise_and(iweights, (2**bits) - 1)
|
||||
zeros = torch.bitwise_and(zeros, (2**bits) - 1)
|
||||
|
||||
scales = scales.repeat_interleave(group_size, dim=0)
|
||||
zeros = zeros.repeat_interleave(group_size, dim=0)
|
||||
return (iweights - zeros) * scales
|
||||
|
||||
|
||||
# qweights - [R , C // 8], int32
|
||||
# scales - [R // G, C ], float16
|
||||
# zeros - [R // G, C // 8], int32
|
||||
@pytest.mark.parametrize("qweight_rows", [3584, 18944, 128, 256, 512, 1024])
|
||||
@pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128])
|
||||
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
|
||||
def test_dequantize(qweight_rows, qweight_cols, group_size):
|
||||
|
||||
if group_size == -1:
|
||||
group_size = qweight_rows
|
||||
|
||||
qweight_dtype = torch.int32
|
||||
scales_rows = qweight_rows // group_size
|
||||
scales_cols = qweight_cols * 8
|
||||
scales_dtype = torch.float16
|
||||
zeros_rows = scales_rows
|
||||
zeros_cols = qweight_cols
|
||||
zeros_dtype = torch.int32
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
qweight = torch.randint(0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(qweight_rows, qweight_cols),
|
||||
dtype=qweight_dtype,
|
||||
device=device)
|
||||
scales = torch.rand(scales_rows,
|
||||
scales_cols,
|
||||
dtype=scales_dtype,
|
||||
device=device)
|
||||
zeros = torch.randint(0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(zeros_rows, zeros_cols),
|
||||
dtype=zeros_dtype,
|
||||
device=device)
|
||||
|
||||
iweights_triton = awq_dequantize_triton(qweight, scales, zeros)
|
||||
|
||||
assert (not torch.any(torch.isinf(iweights_triton))
|
||||
and not torch.any(torch.isnan(iweights_triton)))
|
||||
|
||||
iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size)
|
||||
|
||||
torch.testing.assert_close(iweights_triton, iweights_torch)
|
||||
|
||||
|
||||
# input - [N, K]
|
||||
# qweight - [K, M // 8]
|
||||
# qzeros - [K // G, M // 8]
|
||||
# scales - [K // G, M]
|
||||
@pytest.mark.parametrize("N", [1, 2, 4, 8, 14, 17, 23, 32])
|
||||
@pytest.mark.parametrize("K", [128])
|
||||
@pytest.mark.parametrize("M", [16, 24, 32])
|
||||
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("splitK", [1, 8])
|
||||
def test_gemm(N, K, M, splitK, group_size):
|
||||
|
||||
if group_size == -1:
|
||||
group_size = K
|
||||
|
||||
split_k_iters = splitK
|
||||
|
||||
input_rows = N
|
||||
input_cols = K
|
||||
input_dtype = torch.float32
|
||||
qweight_rows = input_cols
|
||||
qweight_cols = M // 8
|
||||
scales_rows = qweight_rows // group_size
|
||||
scales_cols = M
|
||||
scales_dtype = torch.float32
|
||||
qzeros_rows = scales_rows
|
||||
qzeros_cols = qweight_cols
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
input = torch.rand((input_rows, input_cols),
|
||||
dtype=input_dtype,
|
||||
device=device)
|
||||
qweight = torch.randint(0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(qweight_rows, qweight_cols),
|
||||
device=device)
|
||||
qzeros = torch.randint(0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(qzeros_rows, qzeros_cols),
|
||||
device=device)
|
||||
scales = torch.rand((scales_rows, scales_cols),
|
||||
dtype=scales_dtype,
|
||||
device=device)
|
||||
|
||||
output_triton = awq_gemm_triton(input, qweight, scales, qzeros,
|
||||
split_k_iters)
|
||||
|
||||
assert (not torch.any(torch.isinf(output_triton))
|
||||
and not torch.any(torch.isnan(output_triton)))
|
||||
|
||||
dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros)
|
||||
|
||||
output_torch = torch.matmul(input, dequantized_weights)
|
||||
|
||||
assert (not torch.any(torch.isinf(output_torch))
|
||||
and not torch.any(torch.isnan(output_torch)))
|
||||
|
||||
torch.testing.assert_close(output_triton.cpu(),
|
||||
output_torch.cpu(),
|
||||
atol=1e-1,
|
||||
rtol=1e-1)
|
205
tests/kernels/test_causal_conv1d.py
Normal file
205
tests/kernels/test_causal_conv1d.py
Normal file
@ -0,0 +1,205 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
|
||||
|
||||
def causal_conv1d_ref(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
initial_states: Optional[torch.Tensor] = None,
|
||||
return_final_states: bool = False,
|
||||
final_states_out: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
):
|
||||
"""
|
||||
x: (batch, dim, seqlen)
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
initial_states: (batch, dim, width - 1)
|
||||
final_states_out: (batch, dim, width - 1)
|
||||
|
||||
out: (batch, dim, seqlen)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
dtype_in = x.dtype
|
||||
x = x.to(weight.dtype)
|
||||
seqlen = x.shape[-1]
|
||||
dim, width = weight.shape
|
||||
if initial_states is None:
|
||||
out = F.conv1d(x,
|
||||
weight.unsqueeze(1),
|
||||
bias,
|
||||
padding=width - 1,
|
||||
groups=dim)
|
||||
else:
|
||||
x = torch.cat([initial_states, x], dim=-1)
|
||||
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
|
||||
out = out[..., :seqlen]
|
||||
if return_final_states:
|
||||
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
|
||||
dtype_in) # (batch, dim, width - 1)
|
||||
if final_states_out is not None:
|
||||
final_states_out.copy_(final_states)
|
||||
else:
|
||||
final_states_out = final_states
|
||||
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
||||
return (out, None) if not return_final_states else (out, final_states_out)
|
||||
|
||||
|
||||
def causal_conv1d_update_ref(x: torch.Tensor,
|
||||
conv_state: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = None):
|
||||
"""
|
||||
x: (batch, dim)
|
||||
conv_state: (batch, dim, width)
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
|
||||
out: (batch, dim)
|
||||
"""
|
||||
if activation not in [None, "silu", "swish"]:
|
||||
raise NotImplementedError("activation must be None, silu, or swish")
|
||||
dtype_in = x.dtype
|
||||
batch, dim = x.shape
|
||||
width = weight.shape[1]
|
||||
assert conv_state.shape == (batch, dim, width)
|
||||
assert weight.shape == (dim, width)
|
||||
conv_state.copy_(torch.roll(conv_state, shifts=-1,
|
||||
dims=-1)) # Update state (B D W)
|
||||
conv_state[:, :, -1] = x
|
||||
out = torch.sum(conv_state * weight, dim=-1) # (B D)
|
||||
if bias is not None:
|
||||
out += bias
|
||||
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("return_final_states", [False, True])
|
||||
@pytest.mark.parametrize("has_initial_states", [False, True])
|
||||
@pytest.mark.parametrize("channel_last", [False, True])
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||
@pytest.mark.parametrize("has_bias", [False, True])
|
||||
@pytest.mark.parametrize("width", [4])
|
||||
@pytest.mark.parametrize("seqlen", [128, 512, 4096])
|
||||
@pytest.mark.parametrize('dim', [64, 4096 + 32])
|
||||
@pytest.mark.parametrize('batch', [1, 2])
|
||||
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
|
||||
itype, channel_last, has_initial_states,
|
||||
return_final_states):
|
||||
if not channel_last and (has_initial_states or return_final_states):
|
||||
pytest.skip(
|
||||
"Only channel_last support initial_states or return_final_states")
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
if not channel_last:
|
||||
x = torch.randn(batch,
|
||||
4096 + dim + 64,
|
||||
seqlen,
|
||||
device=device,
|
||||
dtype=itype)[:, 4096:4096 + dim, :]
|
||||
else:
|
||||
x = rearrange(
|
||||
torch.randn(batch,
|
||||
seqlen,
|
||||
4096 + dim + 64,
|
||||
device=device,
|
||||
dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s")
|
||||
weight = torch.randn(dim, width, device=device, dtype=itype)
|
||||
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
|
||||
if has_initial_states:
|
||||
initial_states = torch.randn(batch,
|
||||
width - 1,
|
||||
dim,
|
||||
device=device,
|
||||
dtype=itype).transpose(1, 2)
|
||||
else:
|
||||
initial_states = None
|
||||
x_ref = x.detach().clone()
|
||||
weight_ref = weight.detach().clone()
|
||||
bias_ref = bias.detach().clone() if bias is not None else None
|
||||
initial_states_ref = initial_states.detach().clone(
|
||||
) if initial_states is not None else None
|
||||
activation = None if not silu_activation else "silu"
|
||||
out, final_states = causal_conv1d_fn(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
initial_states=initial_states,
|
||||
return_final_states=return_final_states,
|
||||
activation=activation)
|
||||
out_ref, final_states_ref = causal_conv1d_ref(
|
||||
x_ref,
|
||||
weight_ref,
|
||||
bias_ref,
|
||||
initial_states=initial_states_ref,
|
||||
return_final_states=return_final_states,
|
||||
activation=activation)
|
||||
if return_final_states:
|
||||
assert final_states is not None and final_states_ref is not None
|
||||
assert torch.allclose(final_states,
|
||||
final_states_ref,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
if return_final_states:
|
||||
out += F.sigmoid(final_states).sum(dim=-1, keepdim=True)
|
||||
out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("silu_activation", [False, True])
|
||||
@pytest.mark.parametrize("has_bias", [False, True])
|
||||
@pytest.mark.parametrize("width", [2, 3, 4])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
@pytest.mark.parametrize("batch", [1, 2])
|
||||
def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
|
||||
itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch = 2
|
||||
x = torch.randn(batch, dim, device=device, dtype=itype)
|
||||
conv_state = torch.randn(batch, dim, width, device=device, dtype=itype)
|
||||
weight = torch.randn(dim,
|
||||
width,
|
||||
device=device,
|
||||
dtype=itype,
|
||||
requires_grad=True)
|
||||
if has_bias:
|
||||
bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
|
||||
else:
|
||||
bias = None
|
||||
conv_state_ref = conv_state.detach().clone()
|
||||
activation = None if not silu_activation else "silu"
|
||||
out = causal_conv1d_update(x,
|
||||
conv_state,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation)
|
||||
out_ref = causal_conv1d_update_ref(x,
|
||||
conv_state_ref,
|
||||
weight,
|
||||
bias,
|
||||
activation=activation)
|
||||
|
||||
assert torch.equal(conv_state, conv_state_ref)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
@ -73,11 +73,14 @@ def ref_paged_attn(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
|
||||
num_heads: Tuple[int,
|
||||
int], head_size: int,
|
||||
dtype: torch.dtype, block_size: int,
|
||||
soft_cap: Optional[float]) -> None:
|
||||
def test_flashinfer_decode_with_paged_kv(
|
||||
kv_lens: List[int],
|
||||
num_heads: Tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: Optional[float],
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
num_seqs = len(kv_lens)
|
||||
@ -88,6 +91,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
|
||||
scale = head_size**-0.5
|
||||
|
||||
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
|
||||
key_value_cache = torch.randn(NUM_BLOCKS,
|
||||
2,
|
||||
block_size,
|
||||
@ -125,7 +129,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
|
||||
wrapper = flashinfer.\
|
||||
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
|
||||
use_tensor_cores=(
|
||||
(num_query_heads//num_kv_heads) not in (1, 2, 4, 8))
|
||||
(num_query_heads//num_kv_heads) > 4)
|
||||
)
|
||||
wrapper.begin_forward(kv_indptr,
|
||||
kv_indices,
|
||||
@ -249,3 +253,215 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
|
||||
soft_cap=soft_cap)
|
||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]])
|
||||
@pytest.mark.parametrize("num_heads", [(32, 8), (6, 1)])
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
|
||||
def test_flashinfer_prefill_with_paged_fp8_kv(
|
||||
seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int],
|
||||
head_size: int, dtype: torch.dtype, block_size: int,
|
||||
soft_cap: Optional[float]) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [x[0] for x in seq_lens]
|
||||
kv_lens = [x[1] for x in seq_lens]
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_kv_len = max(kv_lens)
|
||||
scale = head_size**-0.5
|
||||
|
||||
kv_cache_dtype = torch.float8_e4m3fn
|
||||
|
||||
query = torch.randn(sum(query_lens),
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
NUM_BLOCKS_FP8 = 2048
|
||||
key_value_cache = torch.randn(NUM_BLOCKS_FP8,
|
||||
2,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
|
||||
key_cache /= head_size**0.5
|
||||
value_cache /= head_size**0.5
|
||||
|
||||
k_scale = key_cache.amax().item() / 448.0
|
||||
v_scale = value_cache.amax().item() / 448.0
|
||||
|
||||
kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale],
|
||||
dim=1).to(kv_cache_dtype)
|
||||
|
||||
assert (kv_cache_fp8.shape == key_value_cache.shape)
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS_FP8,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
|
||||
qo_indptr = [0]
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
seq_len = kv_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
qo_indptr.append(qo_indptr[-1] + query_lens[i])
|
||||
|
||||
qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32)
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD")
|
||||
wrapper.begin_forward(
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
)
|
||||
|
||||
output = wrapper.forward(query,
|
||||
kv_cache_fp8,
|
||||
logits_soft_cap=soft_cap,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale)
|
||||
|
||||
ref_output = ref_paged_attn(query=query,
|
||||
key_cache=key_cache.squeeze(1),
|
||||
value_cache=value_cache.squeeze(1),
|
||||
query_lens=query_lens,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap)
|
||||
del query
|
||||
del block_tables
|
||||
# verify prefill fp8
|
||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
|
||||
@pytest.mark.parametrize("num_heads", [(32, 8), (64, 8), (6, 1)])
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_decode_with_paged_fp8_kv(
|
||||
kv_lens: List[int],
|
||||
num_heads: Tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: Optional[float],
|
||||
) -> None:
|
||||
# test doesn't work for num_heads = (16,16)
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
num_seqs = len(kv_lens)
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_kv_len = max(kv_lens)
|
||||
scale = head_size**-0.5
|
||||
use_tensor_cores = (num_query_heads // num_kv_heads) > 4
|
||||
kv_cache_dtype = torch.float8_e4m3fn
|
||||
|
||||
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
NUM_BLOCKS_FP8 = 2048
|
||||
key_value_cache = torch.randn(NUM_BLOCKS_FP8,
|
||||
2,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
|
||||
key_cache /= head_size**0.5
|
||||
value_cache /= head_size**0.5
|
||||
|
||||
k_scale = key_cache.amax().item() / 448.0
|
||||
v_scale = value_cache.amax().item() / 448.0
|
||||
|
||||
key_cache_fp8 = (key_cache / k_scale).to(kv_cache_dtype)
|
||||
value_cache_fp8 = (value_cache / v_scale).to(kv_cache_dtype)
|
||||
assert (key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1)
|
||||
kv_cache_fp8 = torch.cat([key_cache_fp8, value_cache_fp8], dim=1)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS_FP8,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
seq_len = kv_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.\
|
||||
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
|
||||
use_tensor_cores=use_tensor_cores)
|
||||
wrapper.begin_forward(kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
data_type=dtype)
|
||||
output = wrapper.forward(query,
|
||||
kv_cache_fp8,
|
||||
logits_soft_cap=soft_cap,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale)
|
||||
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
|
||||
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
|
||||
|
||||
ref_output = ref_paged_attn(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=[1] * num_seqs,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap)
|
||||
# Temporary fix: Increasing the tolerance. Seems like a flashinfer issue
|
||||
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
|
324
tests/kernels/test_mamba_ssm.py
Normal file
324
tests/kernels/test_mamba_ssm.py
Normal file
@ -0,0 +1,324 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_scan_fn, selective_state_update)
|
||||
|
||||
|
||||
def selective_state_update_ref(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
dt_softplus=False):
|
||||
"""
|
||||
Argument:
|
||||
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
x: (batch, dim) or (batch, nheads, dim)
|
||||
dt: (batch, dim) or (batch, nheads, dim)
|
||||
A: (dim, dstate) or (nheads, dim, dstate)
|
||||
B: (batch, dstate) or (batch, ngroups, dstate)
|
||||
C: (batch, dstate) or (batch, ngroups, dstate)
|
||||
D: (dim,) or (nheads, dim)
|
||||
z: (batch, dim) or (batch, nheads, dim)
|
||||
dt_bias: (dim,) or (nheads, dim)
|
||||
Return:
|
||||
out: (batch, dim) or (batch, nheads, dim)
|
||||
"""
|
||||
has_heads = state.dim() > 3
|
||||
if state.dim() == 3:
|
||||
state = state.unsqueeze(1)
|
||||
if x.dim() == 2:
|
||||
x = x.unsqueeze(1)
|
||||
if dt.dim() == 2:
|
||||
dt = dt.unsqueeze(1)
|
||||
if A.dim() == 2:
|
||||
A = A.unsqueeze(0)
|
||||
if B.dim() == 2:
|
||||
B = B.unsqueeze(1)
|
||||
if C.dim() == 2:
|
||||
C = C.unsqueeze(1)
|
||||
if D is not None and D.dim() == 1:
|
||||
D = D.unsqueeze(0)
|
||||
if z is not None and z.dim() == 2:
|
||||
z = z.unsqueeze(1)
|
||||
if dt_bias is not None and dt_bias.dim() == 1:
|
||||
dt_bias = dt_bias.unsqueeze(0)
|
||||
batch, nheads, dim, dstate = state.shape
|
||||
assert x.shape == (batch, nheads, dim)
|
||||
assert dt.shape == x.shape
|
||||
assert A.shape == (nheads, dim, dstate)
|
||||
ngroups = B.shape[1]
|
||||
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
||||
assert B.shape == (batch, ngroups, dstate)
|
||||
assert C.shape == B.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, dim)
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads, dim)
|
||||
dt = dt + dt_bias
|
||||
dt = F.softplus(dt) if dt_softplus else dt
|
||||
dA = torch.exp(rearrange(dt, "b h d -> b h d 1") *
|
||||
A) # (batch, nheads, dim, dstate)
|
||||
B = repeat(B, "b g n -> b (g h) n",
|
||||
h=nheads // ngroups) # (batch, nheads, dstate)
|
||||
C = repeat(C, "b g n -> b (g h) n",
|
||||
h=nheads // ngroups) # (batch, nheads, dstate)
|
||||
dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
|
||||
B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate)
|
||||
state.copy_(state * dA +
|
||||
dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate
|
||||
out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
|
||||
if D is not None:
|
||||
out += (x * D).to(out.dtype)
|
||||
out = (out if z is None else out * F.silu(z)).to(x.dtype)
|
||||
if not has_heads:
|
||||
out = out.squeeze(1)
|
||||
return out
|
||||
|
||||
|
||||
def selective_scan_ref(u,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
delta_bias=None,
|
||||
delta_softplus=False,
|
||||
return_last_state=False,
|
||||
position_indices=None,
|
||||
prev_state=None):
|
||||
"""
|
||||
u: r(B D L)
|
||||
delta: r(B D L)
|
||||
A: c(D N) or r(D N)
|
||||
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
||||
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
||||
D: r(D)
|
||||
z: r(B D L)
|
||||
delta_bias: r(D), fp32
|
||||
prev_state: r(B D N), fp32
|
||||
|
||||
out: r(B D L)
|
||||
last_state (optional): r(B D dstate) or c(B D dstate)
|
||||
"""
|
||||
dtype_in = u.dtype
|
||||
u = u.float()
|
||||
delta = delta.float()
|
||||
if delta_bias is not None:
|
||||
delta = delta + delta_bias[..., None].float()
|
||||
if delta_softplus:
|
||||
delta = F.softplus(delta)
|
||||
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
|
||||
is_variable_B = B.dim() >= 3
|
||||
is_variable_C = C.dim() >= 3
|
||||
B = B.float()
|
||||
C = C.float()
|
||||
x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state
|
||||
ys = []
|
||||
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
|
||||
if not is_variable_B:
|
||||
deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
|
||||
else:
|
||||
if B.dim() == 3:
|
||||
deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
|
||||
else:
|
||||
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
|
||||
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
|
||||
if is_variable_C and C.dim() == 4:
|
||||
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
|
||||
last_state = None
|
||||
for i in range(u.shape[2]):
|
||||
if position_indices is not None and position_indices[0, i] == 0:
|
||||
x = deltaB_u[:, :, i]
|
||||
else:
|
||||
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
|
||||
if not is_variable_C:
|
||||
y = torch.einsum('bdn,dn->bd', x, C)
|
||||
else:
|
||||
if C.dim() == 3:
|
||||
y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
|
||||
else:
|
||||
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
|
||||
if i == u.shape[2] - 1:
|
||||
last_state = x
|
||||
ys.append(y)
|
||||
y = torch.stack(ys, dim=2) # (batch dim L)
|
||||
out = y if D is None else y + u * rearrange(D, "d -> d 1")
|
||||
if z is not None:
|
||||
out = out * F.silu(z)
|
||||
out = out.to(dtype=dtype_in)
|
||||
return out if not return_last_state else (out, last_state)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('wtype', [torch.float32])
|
||||
@pytest.mark.parametrize('itype', [torch.float32])
|
||||
@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
|
||||
@pytest.mark.parametrize("return_last_state", [True])
|
||||
@pytest.mark.parametrize('has_delta_bias', [True])
|
||||
@pytest.mark.parametrize('delta_softplus', [True])
|
||||
@pytest.mark.parametrize('has_z', [True])
|
||||
@pytest.mark.parametrize('has_D', [True])
|
||||
@pytest.mark.parametrize("varBC_groups", [1, 2])
|
||||
@pytest.mark.parametrize("is_variable_C", [True])
|
||||
@pytest.mark.parametrize("is_variable_B", [True])
|
||||
@pytest.mark.parametrize("scan_chunks", [1, 2, 3])
|
||||
def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
|
||||
has_z, has_delta_bias, delta_softplus,
|
||||
return_last_state, seqlen, itype, wtype, scan_chunks):
|
||||
if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
|
||||
pytest.skip() # This config is not applicable
|
||||
device = 'cuda'
|
||||
rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 3e-2, 5e-2
|
||||
rtolw, atolw = (1e-3, 1e-3)
|
||||
if has_z: # If we have z, the errors on the weights seem higher
|
||||
rtolw = max(rtolw, rtol)
|
||||
atolw = max(atolw, atol)
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 2
|
||||
dim = 4
|
||||
dstate = 8
|
||||
A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype))
|
||||
if not is_variable_B:
|
||||
B_shape = [dim, dstate]
|
||||
elif varBC_groups == 1:
|
||||
B_shape = [batch_size, dstate, seqlen]
|
||||
else:
|
||||
B_shape = [batch_size, varBC_groups, dstate, seqlen]
|
||||
B = torch.randn(B_shape,
|
||||
device=device,
|
||||
dtype=wtype if not is_variable_B else itype)
|
||||
if not is_variable_C:
|
||||
C_shape = [dim, dstate]
|
||||
elif varBC_groups == 1:
|
||||
C_shape = [batch_size, dstate, seqlen]
|
||||
else:
|
||||
C_shape = [batch_size, varBC_groups, dstate, seqlen]
|
||||
C = torch.randn(C_shape,
|
||||
device=device,
|
||||
dtype=wtype if not is_variable_C else itype)
|
||||
D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
|
||||
z = torch.randn(batch_size, dim, seqlen, device=device,
|
||||
dtype=itype) if has_z else None
|
||||
delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)
|
||||
) if has_delta_bias else None
|
||||
u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
|
||||
delta = (0.5 *
|
||||
torch.rand(batch_size, dim, seqlen, device=device, dtype=itype))
|
||||
state = None
|
||||
state_ref = None
|
||||
out = None
|
||||
out_ref = None
|
||||
outs = []
|
||||
for c in range(scan_chunks):
|
||||
chunked_prompt_len = seqlen // scan_chunks
|
||||
chunk_start = chunked_prompt_len * c
|
||||
chunk_end = chunked_prompt_len * (c + 1)
|
||||
if c == scan_chunks - 1:
|
||||
chunk_end = seqlen
|
||||
_B = B
|
||||
if is_variable_B:
|
||||
_B = B[..., chunk_start:chunk_end]
|
||||
_C = C
|
||||
if is_variable_B:
|
||||
_C = C[..., chunk_start:chunk_end]
|
||||
_z = z
|
||||
if has_z:
|
||||
assert z is not None
|
||||
_z = z[..., chunk_start:chunk_end]
|
||||
out, *rest = selective_scan_fn(u[..., chunk_start:chunk_end],
|
||||
delta[..., chunk_start:chunk_end],
|
||||
A,
|
||||
_B,
|
||||
_C,
|
||||
D,
|
||||
z=_z,
|
||||
delta_bias=delta_bias,
|
||||
delta_softplus=delta_softplus,
|
||||
return_last_state=return_last_state,
|
||||
prev_state=state if c > 0 else None)
|
||||
outs.append(out)
|
||||
if return_last_state:
|
||||
state = rest[0]
|
||||
if len(outs) > 1:
|
||||
out = torch.cat(outs, dim=-1)
|
||||
out_ref, *rest = selective_scan_ref(u,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
z=z,
|
||||
delta_bias=delta_bias,
|
||||
delta_softplus=delta_softplus,
|
||||
return_last_state=return_last_state)
|
||||
if return_last_state:
|
||||
state_ref = rest[0]
|
||||
|
||||
assert out is not None and out_ref is not None
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
if return_last_state:
|
||||
assert state is not None and state_ref is not None
|
||||
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("has_z", [False, True])
|
||||
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
||||
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||
def test_selective_state_update(dim, dstate, has_z, itype):
|
||||
device = "cuda"
|
||||
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
if torch.version.hip:
|
||||
atol *= 2
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 1
|
||||
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
|
||||
x = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||
dt_bias = torch.rand(dim, device=device) - 4.0
|
||||
A = -torch.rand(dim, dstate, device=device) - 1.0
|
||||
B = torch.randn(batch_size, dstate, device=device)
|
||||
C = torch.randn(batch_size, dstate, device=device)
|
||||
D = torch.randn(dim, device=device)
|
||||
z = torch.randn_like(x) if has_z else None
|
||||
state_ref = state.detach().clone()
|
||||
out = selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True)
|
||||
out_ref = selective_state_update_ref(state_ref,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True)
|
||||
|
||||
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
@ -1,7 +1,10 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
import vllm
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.utils import is_hip
|
||||
|
||||
MODEL_PATH = "google/gemma-7b"
|
||||
|
||||
@ -10,7 +13,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
|
||||
prompts = [
|
||||
"Quote: Imagination is",
|
||||
"Quote: Be yourself;",
|
||||
"Quote: So many books,",
|
||||
"Quote: Painting is poetry that is seen rather than felt,",
|
||||
]
|
||||
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
|
||||
outputs = llm.generate(
|
||||
@ -28,6 +31,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
|
||||
return generated_texts
|
||||
|
||||
|
||||
@pytest.mark.xfail(is_hip(), reason="There can be output mismatch on ROCm")
|
||||
def test_gemma_lora(gemma_lora_files):
|
||||
llm = vllm.LLM(MODEL_PATH,
|
||||
max_model_len=1024,
|
||||
@ -37,7 +41,8 @@ def test_gemma_lora(gemma_lora_files):
|
||||
expected_lora_output = [
|
||||
"more important than knowledge.\nAuthor: Albert Einstein\n",
|
||||
"everyone else is already taken.\nAuthor: Oscar Wilde\n",
|
||||
"so little time.\nAuthor: Frank Zappa\n",
|
||||
"and poetry is painting that is felt rather than seen.\n"
|
||||
"Author: Leonardo da Vinci\n",
|
||||
]
|
||||
|
||||
output1 = do_sample(llm, gemma_lora_files, lora_id=1)
|
||||
|
@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
import vllm
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.utils import is_hip
|
||||
|
||||
from .conftest import cleanup
|
||||
|
||||
@ -17,12 +18,23 @@ class ModelWithQuantization:
|
||||
quantization: str
|
||||
|
||||
|
||||
MODELS: List[ModelWithQuantization] = [
|
||||
ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
|
||||
quantization="AWQ"),
|
||||
ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
|
||||
quantization="GPTQ"),
|
||||
]
|
||||
MODELS: List[ModelWithQuantization]
|
||||
#AWQ quantization is currently not supported in ROCm.
|
||||
if is_hip():
|
||||
MODELS = [
|
||||
ModelWithQuantization(
|
||||
model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
|
||||
quantization="GPTQ"),
|
||||
]
|
||||
else:
|
||||
MODELS = [
|
||||
ModelWithQuantization(
|
||||
model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
|
||||
quantization="AWQ"),
|
||||
ModelWithQuantization(
|
||||
model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
|
||||
quantization="GPTQ"),
|
||||
]
|
||||
|
||||
|
||||
def do_sample(llm: vllm.LLM,
|
||||
|
@ -3,116 +3,97 @@
|
||||
Note: these tests will only pass on L4 GPU.
|
||||
"""
|
||||
import os
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.kernels.utils import override_backend_env_variable
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
from ..models.utils import check_logprobs_close
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
MAX_MODEL_LEN = 1024
|
||||
|
||||
MODELS = [
|
||||
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV",
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
]
|
||||
|
||||
EXPECTED_STRS_MAP = {
|
||||
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV": {
|
||||
"auto": [
|
||||
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
|
||||
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
|
||||
'Artificial intelligence (AI) and human intelligence (HI) process information in distinct ways, with both',
|
||||
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
|
||||
'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
|
||||
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
|
||||
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
|
||||
'Here are the translations:\n\n**Japanese:** (Haya aki no tori, nemuri no'
|
||||
],
|
||||
"fp8": [
|
||||
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
|
||||
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
|
||||
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
|
||||
'A neural network is a complex system made up of several basic components that work together to enable it to',
|
||||
'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like',
|
||||
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
|
||||
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
|
||||
'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk'
|
||||
]
|
||||
},
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct": {
|
||||
"auto": [
|
||||
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
|
||||
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
|
||||
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
|
||||
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
|
||||
'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short',
|
||||
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
|
||||
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
|
||||
'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
|
||||
],
|
||||
"fp8": [
|
||||
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
|
||||
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
|
||||
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
|
||||
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
|
||||
'In the year 2154, robotics engineer Dr. Rachel Kim had spent years perfecting her latest',
|
||||
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
|
||||
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
|
||||
'Here are the translations:\n\n**Japanese:** (Haya tori, mushi o tsukamu'
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# This test compares against golden strings for exact match since
|
||||
# there is no baseline implementation to compare against
|
||||
# and is unstable w.r.t specifics of the fp8 implementation or
|
||||
# the hardware being run on.
|
||||
# Disabled to prevent it from breaking the build
|
||||
@pytest.mark.skip(
|
||||
reason=
|
||||
"Prevent unstable test based on golden strings from breaking the build.")
|
||||
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
||||
reason="fp8 is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model_name", MODELS)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
|
||||
def test_models(example_prompts, model_name, kv_cache_dtype) -> None:
|
||||
model = LLM(model=model_name,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
trust_remote_code=True,
|
||||
enforce_eager=True,
|
||||
quantization="fp8",
|
||||
kv_cache_dtype=kv_cache_dtype)
|
||||
@pytest.mark.parametrize(
|
||||
"kv_cache_dtype,base_model,test_model,scale_path",
|
||||
[
|
||||
# Test FP8 checkpoint w. fp8_e4m3 kv-cache scaling factors.
|
||||
("fp8_e4m3", "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV", None),
|
||||
# Test FP16 checkpoint w. fp8_e5m2 kv-cache.
|
||||
("fp8_e5m2", "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct", None),
|
||||
# Test FP16 checkpoint w. fp8_e4m3 kv-cache scaling factors in json.
|
||||
("fp8_e4m3", "meta-llama/Llama-2-7b-chat-hf",
|
||||
"meta-llama/Llama-2-7b-chat-hf",
|
||||
"./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json")
|
||||
])
|
||||
# Due to low-precision numerical divergence, we only test logprob of 4 tokens
|
||||
@pytest.mark.parametrize("max_tokens", [4])
|
||||
@pytest.mark.parametrize("enforce_eager", [False, True])
|
||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
|
||||
# NOTE: Increasing this in this suite will fail CI because we currently cannot
|
||||
# reset distributed env properly. Use a value > 1 just when you test.
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1])
|
||||
# Due to low-precision numerical divergence, this test is too sensitive for
|
||||
# the async postprocessor
|
||||
@pytest.mark.parametrize("disable_async_output_proc", [True])
|
||||
def test_models(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
kv_cache_dtype: str,
|
||||
base_model: str,
|
||||
test_model: str,
|
||||
scale_path: Optional[str],
|
||||
max_tokens: int,
|
||||
enforce_eager: bool,
|
||||
backend: str,
|
||||
tensor_parallel_size: int,
|
||||
disable_async_output_proc: bool,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
"""
|
||||
Only checks log probs match to cover the discrepancy in
|
||||
numerical sensitive kernels.
|
||||
"""
|
||||
override_backend_env_variable(monkeypatch, backend)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
formatted_prompts = [
|
||||
tokenizer.apply_chat_template([{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
for prompt in example_prompts
|
||||
]
|
||||
MAX_MODEL_LEN = 1024
|
||||
NUM_LOG_PROBS = 8
|
||||
|
||||
params = SamplingParams(max_tokens=20, temperature=0)
|
||||
generations: List[str] = []
|
||||
# Note: these need to be run 1 at a time due to numerical precision,
|
||||
# since the expected strs were generated this way.
|
||||
for prompt in formatted_prompts:
|
||||
outputs = model.generate(prompt, params)
|
||||
generations.append(outputs[0].outputs[0].text)
|
||||
del model
|
||||
with vllm_runner(
|
||||
base_model,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
enforce_eager=enforce_eager,
|
||||
kv_cache_dtype="auto",
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
) as vllm_model:
|
||||
baseline_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, NUM_LOG_PROBS)
|
||||
|
||||
print(model_name, kv_cache_dtype, generations)
|
||||
expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype]
|
||||
for i in range(len(example_prompts)):
|
||||
generated_str = generations[i]
|
||||
expected_str = expected_strs[i]
|
||||
assert expected_str == generated_str, (
|
||||
f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}")
|
||||
extra_kwargs = {}
|
||||
if scale_path is not None:
|
||||
extra_kwargs["quantization_param_path"] = scale_path
|
||||
|
||||
with vllm_runner(
|
||||
test_model,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
enforce_eager=enforce_eager,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
**extra_kwargs,
|
||||
) as vllm_model:
|
||||
test_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, NUM_LOG_PROBS)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=baseline_outputs,
|
||||
outputs_1_lst=test_outputs,
|
||||
name_0="fp16_kv_cache",
|
||||
name_1="fp8_kv_cache",
|
||||
)
|
||||
|
49
tests/models/test_granite.py
Normal file
49
tests/models/test_granite.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""Compare the outputs of HF and vLLM for Granite models using greedy sampling.
|
||||
|
||||
Run `pytest tests/models/test_granite.py`.
|
||||
"""
|
||||
import importlib.metadata
|
||||
|
||||
import pytest
|
||||
|
||||
from .utils import check_logprobs_close
|
||||
|
||||
TRANSFORMERS_VERSION = tuple(
|
||||
map(int,
|
||||
importlib.metadata.version("transformers").split(".")))
|
||||
|
||||
MODELS = [
|
||||
"ibm/PowerLM-3b",
|
||||
]
|
||||
|
||||
|
||||
# GraniteForCausalLM will be in transformers >= 4.45
|
||||
@pytest.mark.skipif(TRANSFORMERS_VERSION < (4, 45),
|
||||
reason="granite model test requires transformers >= 4.45")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
# TODO(sang): Sliding window should be tested separately.
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
@ -6,8 +6,6 @@ import torch.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoConfig, AutoModel, CLIPImageProcessor
|
||||
|
||||
from vllm.model_executor.models.intern_vit import InternVisionModel
|
||||
|
||||
from ..conftest import _ImageAssets, cleanup
|
||||
|
||||
pytestmark = pytest.mark.vlm
|
||||
@ -49,6 +47,7 @@ def run_intern_vit_test(
|
||||
for pixel_value in pixel_values
|
||||
]
|
||||
|
||||
from vllm.model_executor.models.intern_vit import InternVisionModel
|
||||
vllm_model = InternVisionModel(config)
|
||||
vllm_model.load_weights(hf_model.state_dict().items())
|
||||
|
||||
|
@ -3,13 +3,9 @@ from typing import List, Optional, Tuple, Type
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL.Image import Image
|
||||
from transformers import AutoConfig
|
||||
|
||||
from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END,
|
||||
IMG_START,
|
||||
image_to_pixel_values)
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.utils import is_cpu
|
||||
|
||||
@ -25,49 +21,15 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501
|
||||
})
|
||||
|
||||
# we use snapshot_download to prevent conflicts between
|
||||
# dynamic_module and trust_remote_code for hf_runner
|
||||
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
|
||||
models = [
|
||||
snapshot_download("OpenGVLab/InternVL2-1B",
|
||||
allow_patterns=DOWNLOAD_PATTERN),
|
||||
snapshot_download("OpenGVLab/InternVL2-2B",
|
||||
allow_patterns=DOWNLOAD_PATTERN),
|
||||
"OpenGVLab/InternVL2-1B",
|
||||
"OpenGVLab/InternVL2-2B",
|
||||
# Broken due to outdated implementation of Phi-3
|
||||
# See: https://huggingface.co/OpenGVLab/InternVL2-4B/discussions/3
|
||||
# snapshot_download("OpenGVLab/InternVL2-4B"),
|
||||
# "OpenGVLab/InternVL2-4B",
|
||||
]
|
||||
|
||||
|
||||
class InternVLProcessor:
|
||||
"""A simple processor for InternVL2 HF model which misses a processor."""
|
||||
|
||||
def __init__(self, hf_runner: HfRunner):
|
||||
self.num_image_token = hf_runner.model.num_image_token
|
||||
self.tokenizer = hf_runner.tokenizer
|
||||
self.dtype = hf_runner.model.dtype
|
||||
|
||||
self.config = AutoConfig.from_pretrained(hf_runner.model_name)
|
||||
self.vision_config = self.config.vision_config
|
||||
self.use_thumbnail = self.config.use_thumbnail
|
||||
self.min_num = self.config.min_dynamic_patch
|
||||
self.max_num = self.config.max_dynamic_patch
|
||||
self.image_size = self.vision_config.image_size
|
||||
|
||||
def __call__(self, text: str, images: Image, **kwargs):
|
||||
pixel_values = image_to_pixel_values(images, self.image_size,
|
||||
self.min_num, self.max_num,
|
||||
self.use_thumbnail).to(self.dtype)
|
||||
num_patches_list = [pixel_values.shape[0]]
|
||||
for num_patches in num_patches_list:
|
||||
context_tokens = IMG_CONTEXT * self.num_image_token * num_patches
|
||||
image_tokens = IMG_START + context_tokens + IMG_END
|
||||
text = text.replace('<image>', image_tokens, 1)
|
||||
prompt = self.tokenizer(text, return_tensors="pt")
|
||||
prompt.update({"pixel_values": pixel_values})
|
||||
return prompt
|
||||
|
||||
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py
|
||||
def generate(
|
||||
self,
|
||||
@ -133,6 +95,37 @@ def run_test(
|
||||
# if we run HF first, the cuda initialization will be done and it
|
||||
# will hurt multiprocessing backend with fork method (the default method).
|
||||
|
||||
class InternVLProcessor:
|
||||
"""A simple processor for InternVL2 which misses a processor."""
|
||||
|
||||
def __init__(self, hf_runner: HfRunner):
|
||||
self.num_image_token = hf_runner.model.num_image_token
|
||||
self.tokenizer = hf_runner.tokenizer
|
||||
self.dtype = hf_runner.model.dtype
|
||||
|
||||
self.config = AutoConfig.from_pretrained(hf_runner.model_name)
|
||||
self.vision_config = self.config.vision_config
|
||||
self.use_thumbnail = self.config.use_thumbnail
|
||||
self.min_num = self.config.min_dynamic_patch
|
||||
self.max_num = self.config.max_dynamic_patch
|
||||
self.image_size = self.vision_config.image_size
|
||||
|
||||
def __call__(self, text: str, images: Image, **kwargs):
|
||||
from vllm.model_executor.models.internvl import (
|
||||
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
|
||||
pixel_values = image_to_pixel_values(
|
||||
images, self.image_size, self.min_num, self.max_num,
|
||||
self.use_thumbnail).to(self.dtype)
|
||||
num_patches_list = [pixel_values.shape[0]]
|
||||
for num_patches in num_patches_list:
|
||||
context_tokens = IMG_CONTEXT * self.num_image_token \
|
||||
* num_patches
|
||||
image_tokens = IMG_START + context_tokens + IMG_END
|
||||
text = text.replace('<image>', image_tokens, 1)
|
||||
prompt = self.tokenizer(text, return_tensors="pt")
|
||||
prompt.update({"pixel_values": pixel_values})
|
||||
return prompt
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model,
|
||||
max_model_len=4096,
|
||||
|
@ -179,3 +179,20 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
def test_context_length_too_short(vllm_runner, image_assets, model):
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
with pytest.raises(ValueError, match="too long to fit into the model"):
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
max_model_len=128, # LLaVA has a feature size of 576
|
||||
enforce_eager=True,
|
||||
)
|
||||
|
||||
with vllm_model:
|
||||
vllm_model.generate_greedy([HF_IMAGE_PROMPTS[0]],
|
||||
max_tokens=1,
|
||||
images=[images[0]])
|
||||
|
@ -6,24 +6,22 @@ from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
|
||||
from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
|
||||
_ImageAssets)
|
||||
from .utils import check_logprobs_close
|
||||
|
||||
pytestmark = pytest.mark.vlm
|
||||
|
||||
_PREFACE = (
|
||||
"A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's "
|
||||
"questions.")
|
||||
_LIMIT_IMAGE_PER_PROMPT = 4
|
||||
|
||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"stop_sign":
|
||||
f"{_PREFACE} USER: <image>\nWhat's the content of the image? ASSISTANT:",
|
||||
"[INST] <image>\nWhat's the content of the image? [/INST]",
|
||||
"cherry_blossom":
|
||||
f"{_PREFACE} USER: <image>\nWhat is the season? ASSISTANT:",
|
||||
"[INST] <image>\nWhat is the season? [/INST]",
|
||||
})
|
||||
|
||||
models = ["llava-hf/llava-v1.6-vicuna-7b-hf"]
|
||||
models = ["llava-hf/llava-v1.6-mistral-7b-hf"]
|
||||
|
||||
|
||||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||
@ -114,19 +112,43 @@ def run_test(
|
||||
else:
|
||||
raise ValueError("You must provide either `size_factors` or `sizes`")
|
||||
|
||||
_run_test(hf_runner,
|
||||
vllm_runner,
|
||||
inputs_per_image,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend)
|
||||
|
||||
|
||||
def _run_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
inputs: List[Tuple[List[str], PromptImageInput]],
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model,
|
||||
dtype=dtype,
|
||||
max_model_len=4096,
|
||||
max_model_len=10240,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True) as vllm_model:
|
||||
enforce_eager=True,
|
||||
limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
|
||||
}) as vllm_model:
|
||||
vllm_outputs_per_image = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images)
|
||||
for prompts, images in inputs_per_image
|
||||
for prompts, images in inputs
|
||||
]
|
||||
|
||||
with hf_runner(model, dtype=dtype,
|
||||
@ -136,7 +158,7 @@ def run_test(
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images)
|
||||
for prompts, images in inputs_per_image
|
||||
for prompts, images in inputs
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
|
||||
@ -177,7 +199,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
|
||||
|
||||
All the image fixtures for the test is under tests/images.
|
||||
For huggingface runner, we provide the PIL images as input.
|
||||
For vllm runner, we provide MultiModalDataDict objects
|
||||
For vllm runner, we provide MultiModalDataDict objects
|
||||
and corresponding MultiModalConfig as input.
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
@ -216,3 +238,48 @@ def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
|
||||
model, dtype, max_tokens,
|
||||
num_logprobs) -> None:
|
||||
stop_sign = image_assets[0].pil_image
|
||||
cherry_blossom = image_assets[1].pil_image
|
||||
|
||||
inputs = [(
|
||||
[
|
||||
"[INST] <image><image>\nDescribe 2 images. [/INST]",
|
||||
"[INST] <image><image>\nDescribe 2 images. [/INST]",
|
||||
"[INST] <image><image><image><image>\nDescribe 4 images. [/INST]",
|
||||
"[INST] <image>\nWhat is the season? [/INST]"
|
||||
],
|
||||
[
|
||||
[stop_sign, cherry_blossom],
|
||||
# Images with different sizes and aspect-ratios
|
||||
[
|
||||
rescale_image_size(stop_sign, 0.1),
|
||||
stop_sign,
|
||||
],
|
||||
[
|
||||
stop_sign,
|
||||
rescale_image_size(stop_sign, 0.25),
|
||||
cherry_blossom.resize((183, 488)),
|
||||
cherry_blossom.resize((488, 183))
|
||||
],
|
||||
cherry_blossom,
|
||||
])]
|
||||
|
||||
_run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
inputs,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
@ -1,14 +1,15 @@
|
||||
from typing import List, Optional, Tuple, Type
|
||||
from typing import List, Optional, Tuple, Type, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.types
|
||||
from PIL import Image
|
||||
from transformers import BatchEncoding
|
||||
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
|
||||
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner
|
||||
from .utils import check_logprobs_close
|
||||
|
||||
pytestmark = pytest.mark.vlm
|
||||
@ -24,6 +25,11 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"(<image>./</image>)\nWhat is the season?<|eot_id|>" \
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
})
|
||||
HF_MULTIIMAGE_IMAGE_PROMPT = \
|
||||
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
|
||||
"(<image>./</image>)\n(<image>./</image>)\n" \
|
||||
"Describe these images.<|eot_id|>" \
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
|
||||
models = ["openbmb/MiniCPM-Llama3-V-2_5"]
|
||||
|
||||
@ -46,13 +52,14 @@ target_dtype = "half"
|
||||
def run_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
image_assets: _ImageAssets,
|
||||
inputs: List[Tuple[List[str], Union[List[Image.Image],
|
||||
List[List[Image.Image]]]]],
|
||||
model: str,
|
||||
*,
|
||||
size_factors: List[float],
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
mm_limit: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
@ -65,12 +72,6 @@ def run_test(
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [(
|
||||
[prompt for _ in size_factors],
|
||||
[rescale_image_size(image, factor) for factor in size_factors],
|
||||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
|
||||
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||
# vLLM needs a fresh new process without cuda initialization.
|
||||
@ -82,6 +83,7 @@ def run_test(
|
||||
max_model_len=4096,
|
||||
max_num_seqs=1,
|
||||
dtype=dtype,
|
||||
limit_mm_per_prompt={"image": mm_limit},
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True) as vllm_model:
|
||||
@ -93,7 +95,7 @@ def run_test(
|
||||
num_logprobs=num_logprobs,
|
||||
images=images,
|
||||
stop_token_ids=stop_token_ids)
|
||||
for prompts, images in inputs_per_image
|
||||
for prompts, images in inputs
|
||||
]
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs)
|
||||
@ -104,7 +106,7 @@ def run_test(
|
||||
num_logprobs=num_logprobs,
|
||||
images=images,
|
||||
tokenizer=tokenizer)
|
||||
for prompts, images in inputs_per_image
|
||||
for prompts, images in inputs
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
|
||||
@ -138,104 +140,26 @@ def run_test(
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
|
||||
dtype: str, max_tokens: int, num_logprobs: int) -> None:
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [(
|
||||
[prompt for _ in size_factors],
|
||||
[rescale_image_size(image, factor) for factor in size_factors],
|
||||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
inputs_per_image,
|
||||
model,
|
||||
size_factors=size_factors,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
mm_limit=1,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
|
||||
HF_MULTIIMAGE_IMAGE_PROMPT = \
|
||||
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
|
||||
"(<image>./</image>)\n(<image>./</image>)\n" \
|
||||
"Describe these images.<|eot_id|>" \
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
|
||||
|
||||
def run_multi_image_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
image_assets: _ImageAssets,
|
||||
model: str,
|
||||
*,
|
||||
size_factors: List[float],
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the image fixtures for the test is under tests/images.
|
||||
For huggingface runner, we provide the PIL images as input.
|
||||
For vllm runner, we provide MultiModalDataDict objects
|
||||
and corresponding MultiModalConfig as input.
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_case = [
|
||||
([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
|
||||
[[rescale_image_size(image, factor) for image in images]
|
||||
for factor in size_factors])
|
||||
]
|
||||
|
||||
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||
# vLLM needs a fresh new process without cuda initialization.
|
||||
# if we run HF first, the cuda initialization will be done and it
|
||||
# will hurt multiprocessing backend with fork method (the default method).
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=1,
|
||||
limit_mm_per_prompt={"image": len(images)},
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True) as vllm_model:
|
||||
tokenizer = vllm_model.model.get_tokenizer()
|
||||
stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
|
||||
vllm_outputs_per_case = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images,
|
||||
stop_token_ids=stop_token_ids)
|
||||
for prompts, images in inputs_per_case
|
||||
]
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs)
|
||||
with hf_model, torch.no_grad():
|
||||
hf_outputs_per_case = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images,
|
||||
tokenizer=tokenizer)
|
||||
for prompts, images in inputs_per_case
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
|
||||
vllm_outputs_per_case):
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=[
|
||||
trunc_hf_output(hf_output) for hf_output in hf_outputs
|
||||
],
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
@ -256,14 +180,22 @@ def run_multi_image_test(
|
||||
def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
|
||||
size_factors, dtype: str, max_tokens: int,
|
||||
num_logprobs: int) -> None:
|
||||
run_multi_image_test(
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_case = [
|
||||
([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
|
||||
[[rescale_image_size(image, factor) for image in images]
|
||||
for factor in size_factors])
|
||||
]
|
||||
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
inputs_per_case,
|
||||
model,
|
||||
size_factors=size_factors,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
mm_limit=2,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
@ -30,9 +30,11 @@ def test_models(
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
with vllm_runner(model, dtype=dtype,
|
||||
tokenizer_mode="mistral") as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
|
@ -1,15 +1,16 @@
|
||||
import os
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Type
|
||||
from typing import List, Optional, Tuple, Type, Union
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.utils import is_cpu, is_hip
|
||||
|
||||
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
|
||||
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner
|
||||
from .utils import check_logprobs_close
|
||||
|
||||
pytestmark = pytest.mark.vlm
|
||||
@ -20,6 +21,7 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"cherry_blossom":
|
||||
"<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
|
||||
})
|
||||
HF_MULTIIMAGE_IMAGE_PROMPT = "<|user|>\n<|image_1|>\n<|image_2|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501
|
||||
|
||||
models = ["microsoft/Phi-3.5-vision-instruct"]
|
||||
|
||||
@ -58,13 +60,14 @@ if is_hip():
|
||||
def run_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
image_assets: _ImageAssets,
|
||||
inputs: List[Tuple[List[str], Union[List[Image.Image],
|
||||
List[List[Image.Image]]]]],
|
||||
model: str,
|
||||
*,
|
||||
size_factors: List[float],
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
mm_limit: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
@ -77,15 +80,6 @@ def run_test(
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [(
|
||||
[prompt for _ in size_factors],
|
||||
[
|
||||
rescale_image_size(image, factor, transpose=idx)
|
||||
for idx, factor in enumerate(size_factors)
|
||||
],
|
||||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
|
||||
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||
# vLLM needs a fresh new process without cuda initialization.
|
||||
@ -97,15 +91,16 @@ def run_test(
|
||||
max_model_len=4096,
|
||||
max_num_seqs=1,
|
||||
dtype=dtype,
|
||||
limit_mm_per_prompt={"image": mm_limit},
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True) as vllm_model:
|
||||
vllm_outputs_per_image = [
|
||||
vllm_outputs_per_case = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images)
|
||||
for prompts, images in inputs_per_image
|
||||
for prompts, images in inputs
|
||||
]
|
||||
|
||||
# use eager mode for hf runner, since phi3_v didn't work with flash_attn
|
||||
@ -113,17 +108,17 @@ def run_test(
|
||||
with hf_runner(model, dtype=dtype,
|
||||
model_kwargs=hf_model_kwargs) as hf_model:
|
||||
eos_token_id = hf_model.processor.tokenizer.eos_token_id
|
||||
hf_outputs_per_image = [
|
||||
hf_outputs_per_case = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images,
|
||||
eos_token_id=eos_token_id)
|
||||
for prompts, images in inputs_per_image
|
||||
for prompts, images in inputs
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
|
||||
vllm_outputs_per_image):
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
|
||||
vllm_outputs_per_case):
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=[
|
||||
@ -156,14 +151,86 @@ def run_test(
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
|
||||
dtype: str, max_tokens: int, num_logprobs: int) -> None:
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [(
|
||||
[prompt for _ in size_factors],
|
||||
[rescale_image_size(image, factor) for factor in size_factors],
|
||||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
inputs_per_image,
|
||||
model,
|
||||
size_factors=size_factors,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
mm_limit=1,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||
def test_regression_7840(hf_runner, vllm_runner, image_assets, model,
|
||||
dtype) -> None:
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_regresion_7840 = [
|
||||
([prompt], [image]) for image, prompt in zip(images, HF_IMAGE_PROMPTS)
|
||||
]
|
||||
|
||||
# Regression test for #7840.
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
inputs_regresion_7840,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_tokens=128,
|
||||
num_logprobs=10,
|
||||
mm_limit=1,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
[
|
||||
# No image
|
||||
[],
|
||||
# Single-scale
|
||||
[1.0],
|
||||
# Single-scale, batched
|
||||
[1.0, 1.0, 1.0],
|
||||
# Multi-scale
|
||||
[0.25, 0.5, 1.0],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
|
||||
size_factors, dtype: str, max_tokens: int,
|
||||
num_logprobs: int) -> None:
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_case = [
|
||||
([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
|
||||
[[rescale_image_size(image, factor) for image in images]
|
||||
for factor in size_factors])
|
||||
]
|
||||
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
inputs_per_case,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
mm_limit=2,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
111
tests/models/test_phimoe.py
Normal file
111
tests/models/test_phimoe.py
Normal file
@ -0,0 +1,111 @@
|
||||
"""Compare the outputs of HF and vLLM for moe models using greedy sampling.
|
||||
|
||||
Run `pytest tests/models/test_phimoe.py`.
|
||||
"""
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.utils import is_cpu
|
||||
|
||||
from .utils import check_logprobs_close
|
||||
|
||||
MODELS = [
|
||||
"microsoft/Phi-3.5-MoE-instruct",
|
||||
]
|
||||
|
||||
|
||||
def test_phimoe_routing_function():
|
||||
from vllm.model_executor.models.phimoe import phimoe_routing_function
|
||||
test_case = {
|
||||
0: {
|
||||
"hidden_states":
|
||||
torch.tensor([1, 2, 3, 4, 5, 6, 7, 8],
|
||||
dtype=torch.float32,
|
||||
requires_grad=False).view(4, 2),
|
||||
"gating_output":
|
||||
torch.tensor([0.1, 0.2, 0.3, 0.4],
|
||||
dtype=torch.float32,
|
||||
requires_grad=False),
|
||||
"topk":
|
||||
2,
|
||||
"renormalize":
|
||||
False,
|
||||
},
|
||||
1: {
|
||||
"hidden_states":
|
||||
torch.tensor([1, 2, 3, 4, 5, 6, 7, 8],
|
||||
dtype=torch.float32,
|
||||
requires_grad=False).view(4, 2),
|
||||
"gating_output":
|
||||
torch.tensor([0.4, 0.2, 0.3, 0.4],
|
||||
dtype=torch.float32,
|
||||
requires_grad=False),
|
||||
"topk":
|
||||
2,
|
||||
"renormalize":
|
||||
False,
|
||||
}
|
||||
}
|
||||
|
||||
ground_truth = {
|
||||
0: {
|
||||
"topk_weights":
|
||||
torch.tensor([1., 1.], dtype=torch.float32, requires_grad=False),
|
||||
"topk_ids":
|
||||
torch.tensor([3, 2], dtype=torch.long, requires_grad=False),
|
||||
},
|
||||
1: {
|
||||
"topk_weights":
|
||||
torch.tensor([0.5, 1.], dtype=torch.float32, requires_grad=False),
|
||||
"topk_ids":
|
||||
torch.tensor([0, 3], dtype=torch.long, requires_grad=False),
|
||||
}
|
||||
}
|
||||
|
||||
for test_id in test_case:
|
||||
topk_weights, topk_ids = phimoe_routing_function(**test_case[test_id])
|
||||
assert torch.allclose(topk_weights,
|
||||
ground_truth[test_id]["topk_weights"])
|
||||
assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"])
|
||||
|
||||
|
||||
def get_gpu_memory():
|
||||
try:
|
||||
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||
gpu_memory = props.total_memory / (1024**3)
|
||||
return gpu_memory
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(condition=is_cpu(),
|
||||
reason="This test takes a lot time to run on CPU, "
|
||||
"and vllm CI's disk space is not enough for this model.")
|
||||
@pytest.mark.skipif(condition=get_gpu_memory() < 100,
|
||||
reason="Skip this test if GPU memory is insufficient.")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
@ -1,11 +1,9 @@
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import pytest
|
||||
from transformers import AutoModel, AutoTokenizer, BatchEncoding
|
||||
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
@ -18,36 +16,32 @@ MODEL_NAME = "fixie-ai/ultravox-v0_3"
|
||||
|
||||
AudioTuple = Tuple[np.ndarray, int]
|
||||
|
||||
VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
|
||||
HF_PLACEHOLDER = "<|audio|>"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def audio_and_sample_rate():
|
||||
return AudioAsset("mary_had_lamb").audio_and_sample_rate
|
||||
def audio_assets():
|
||||
from vllm.assets.audio import AudioAsset
|
||||
return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prompts_and_audios(audio_and_sample_rate):
|
||||
@pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call"))
|
||||
def audio(request):
|
||||
from vllm.assets.audio import AudioAsset
|
||||
return AudioAsset(request.param)
|
||||
|
||||
|
||||
def _get_prompt(audio_count, question, placeholder):
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
placeholder = f"{placeholder}\n" * audio_count
|
||||
|
||||
vllm_placeholder = "<|reserved_special_token_0|>"
|
||||
hf_placeholder = "<|audio|>"
|
||||
|
||||
question = "What's in the audio?"
|
||||
vllm_prompt = tokenizer.apply_chat_template(
|
||||
[{
|
||||
'role': 'user',
|
||||
'content': f"{vllm_placeholder}\n{question}"
|
||||
}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
hf_prompt = tokenizer.apply_chat_template(
|
||||
[{
|
||||
'role': 'user',
|
||||
'content': f"{hf_placeholder}\n{question}"
|
||||
}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
|
||||
return [(vllm_prompt, hf_prompt, audio_and_sample_rate)]
|
||||
return tokenizer.apply_chat_template([{
|
||||
'role': 'user',
|
||||
'content': f"{placeholder}{question}"
|
||||
}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True)
|
||||
|
||||
|
||||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||
@ -109,6 +103,7 @@ def run_test(
|
||||
dtype=dtype,
|
||||
postprocess_inputs=process,
|
||||
auto_cls=AutoModel) as hf_model:
|
||||
import librosa
|
||||
|
||||
hf_outputs_per_audio = [
|
||||
hf_model.generate_greedy_logprobs_limit(
|
||||
@ -134,15 +129,71 @@ def run_test(
|
||||
)
|
||||
|
||||
|
||||
def run_multi_audio_test(
|
||||
vllm_runner: Type[VllmRunner],
|
||||
prompts_and_audios: List[Tuple[str, List[AudioTuple]]],
|
||||
model: str,
|
||||
*,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
with vllm_runner(model,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True,
|
||||
limit_mm_per_prompt={
|
||||
"audio":
|
||||
max((len(audio) for _, audio in prompts_and_audios))
|
||||
}) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
[prompt for prompt, _ in prompts_and_audios],
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
audios=[audios for _, audios in prompts_and_audios])
|
||||
|
||||
# The HuggingFace model doesn't support multiple audios yet, so
|
||||
# just assert that some tokens were generated.
|
||||
assert all(tokens for tokens, *_ in vllm_outputs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(hf_runner, vllm_runner, prompts_and_audios, dtype: str,
|
||||
max_tokens: int, num_logprobs: int) -> None:
|
||||
def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
|
||||
num_logprobs: int) -> None:
|
||||
|
||||
vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER)
|
||||
hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER)
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
prompts_and_audios,
|
||||
[(vllm_prompt, hf_prompt, audio.audio_and_sample_rate)],
|
||||
MODEL_NAME,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int) -> None:
|
||||
|
||||
vllm_prompt = _get_prompt(len(audio_assets),
|
||||
"Describe each of the audios above.",
|
||||
VLLM_PLACEHOLDER)
|
||||
run_multi_audio_test(
|
||||
vllm_runner,
|
||||
[(vllm_prompt, [audio.audio_and_sample_rate
|
||||
for audio in audio_assets])],
|
||||
MODEL_NAME,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
|
@ -1,7 +1,7 @@
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.sequence import Logprob, SampleLogprobs
|
||||
|
||||
TokensText = Tuple[List[int], str]
|
||||
|
||||
@ -38,34 +38,39 @@ TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
|
||||
float]],
|
||||
SampleLogprobs]]]
|
||||
|
||||
# Allow for tokens to be represented as str's rather than IDs
|
||||
TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]],
|
||||
List[Dict[str,
|
||||
Logprob]]]]]
|
||||
|
||||
|
||||
def check_logprobs_close(
|
||||
*,
|
||||
outputs_0_lst: Sequence[TokensTextLogprobs],
|
||||
outputs_1_lst: Sequence[TokensTextLogprobs],
|
||||
outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
|
||||
outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
|
||||
name_0: str,
|
||||
name_1: str,
|
||||
num_outputs_0_skip_tokens: int = 0,
|
||||
warn_on_mismatch: bool = True,
|
||||
):
|
||||
"""
|
||||
Compare the logprobs of two sequences generated by different models,
|
||||
always_check_logprobs: bool = False,
|
||||
) -> None:
|
||||
"""Compare the logprobs of two sequences generated by different models,
|
||||
which should be similar but not necessarily equal.
|
||||
|
||||
Arguments:
|
||||
|
||||
* outputs_0_lst: First sequence to compare
|
||||
* outputs_0_lst: Second sequence to compare
|
||||
* name_0: sequence #0 name
|
||||
* name_1: sequence #1 name
|
||||
* num_outputs_0_skip_tokens: If > 0, specifies the number of initial
|
||||
Args:
|
||||
outputs_0_lst: First sequence to compare
|
||||
outputs_0_lst: Second sequence to compare
|
||||
name_0: sequence #0 name
|
||||
name_1: sequence #1 name
|
||||
num_outputs_0_skip_tokens: If > 0, specifies the number of initial
|
||||
sequence #0 tokens & logprobs to discard
|
||||
before comparison, i.e. all
|
||||
of sequence #1 will be compared to
|
||||
sequence #0 beginning at index
|
||||
num_outputs_0_skip_tokens
|
||||
* warn_on_mismatch: Issue a warning if there is token-wise or text-wise
|
||||
warn_on_mismatch: Issue a warning if there is token-wise or text-wise
|
||||
mismatch between the two sequences
|
||||
always_check_logprobs: If true, check logprobs even when tokens match
|
||||
"""
|
||||
assert len(outputs_0_lst) == len(outputs_1_lst)
|
||||
|
||||
@ -94,8 +99,12 @@ def check_logprobs_close(
|
||||
for idx, (output_id_0,
|
||||
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
|
||||
|
||||
# If generated tokens don't match, then
|
||||
if output_id_0 != output_id_1:
|
||||
is_tok_mismatch = output_id_0 != output_id_1
|
||||
|
||||
# If generated tokens don't match
|
||||
# or it is desired to always check logprobs,
|
||||
# then
|
||||
if is_tok_mismatch or always_check_logprobs:
|
||||
logprobs_elem_0 = logprobs_0[idx]
|
||||
logprobs_elem_1 = logprobs_1[idx]
|
||||
|
||||
@ -111,7 +120,7 @@ def check_logprobs_close(
|
||||
assert output_id_0 in logprobs_elem_1, fail_msg
|
||||
assert output_id_1 in logprobs_elem_0, fail_msg
|
||||
|
||||
if warn_on_mismatch:
|
||||
if warn_on_mismatch and is_tok_mismatch:
|
||||
with warnings.catch_warnings():
|
||||
# This ensures that repeated warnings are shown
|
||||
# in the output, not just the first occurrence
|
||||
|
@ -1,85 +0,0 @@
|
||||
# Test the AsyncLLMEngine with multi-step-decoding
|
||||
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from ..utils import RemoteOpenAIServer
|
||||
|
||||
MODELS = [
|
||||
"JackFram/llama-160m",
|
||||
]
|
||||
NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps
|
||||
NUM_PROMPTS = [10]
|
||||
|
||||
DEFAULT_SERVER_ARGS: List[str] = [
|
||||
"--disable-log-requests",
|
||||
"--use-v2-block-manager",
|
||||
"--worker-use-ray",
|
||||
"--gpu-memory-utilization",
|
||||
"0.85",
|
||||
"--swap-space",
|
||||
"16",
|
||||
]
|
||||
|
||||
|
||||
async def completions_with_server_args(prompts: List[str], model_name: str,
|
||||
server_cli_args: List[str]):
|
||||
|
||||
outputs = None
|
||||
with RemoteOpenAIServer(model_name, server_cli_args) as server:
|
||||
client = server.get_async_client()
|
||||
outputs = await client.completions.create(model=model_name,
|
||||
prompt=prompts,
|
||||
temperature=0,
|
||||
stream=False,
|
||||
max_tokens=5)
|
||||
assert outputs is not None
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize(("tp_size, pp_size"), [
|
||||
(1, 1),
|
||||
(2, 2),
|
||||
])
|
||||
@pytest.mark.parametrize("eager_mode", [False, True])
|
||||
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
|
||||
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_step(example_prompts, model: str, tp_size: int,
|
||||
pp_size: int, eager_mode: int,
|
||||
num_scheduler_steps: int, num_prompts: int):
|
||||
|
||||
prompts = example_prompts
|
||||
if len(prompts) < num_prompts:
|
||||
prompts = prompts * ((num_prompts // len(prompts)) + 1)
|
||||
prompts = prompts[:num_prompts]
|
||||
assert len(prompts) == num_prompts
|
||||
|
||||
server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"]
|
||||
ms_server_args = DEFAULT_SERVER_ARGS + \
|
||||
["--num-scheduler-steps", f"{num_scheduler_steps}"]
|
||||
|
||||
if eager_mode:
|
||||
ms_server_args.append("--enforce-eager")
|
||||
|
||||
distributed_args = [
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
"--pipeline-parallel-size",
|
||||
str(pp_size),
|
||||
]
|
||||
|
||||
ref_completions = await completions_with_server_args(
|
||||
prompts, model, server_args + distributed_args)
|
||||
test_completions = await completions_with_server_args(
|
||||
prompts, model, ms_server_args + distributed_args)
|
||||
|
||||
def get_text_generations(completions):
|
||||
return [x.text for x in completions.choices]
|
||||
|
||||
ref_generations = get_text_generations(ref_completions)
|
||||
test_generations = get_text_generations(test_completions)
|
||||
assert ref_generations == test_generations
|
129
tests/multi_step/test_correctness_async_llm.py
Normal file
129
tests/multi_step/test_correctness_async_llm.py
Normal file
@ -0,0 +1,129 @@
|
||||
# Test the AsyncLLMEngine with multi-step-decoding
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from ..models.utils import check_logprobs_close
|
||||
from ..utils import (completions_with_server_args, get_client_text_generations,
|
||||
get_client_text_logprob_generations)
|
||||
|
||||
MODELS = [
|
||||
"JackFram/llama-160m",
|
||||
]
|
||||
NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps
|
||||
NUM_PROMPTS = [10]
|
||||
|
||||
DEFAULT_SERVER_ARGS: List[str] = [
|
||||
"--disable-log-requests",
|
||||
"--use-v2-block-manager",
|
||||
"--worker-use-ray",
|
||||
"--gpu-memory-utilization",
|
||||
"0.85",
|
||||
"--swap-space",
|
||||
"16",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize(("tp_size, pp_size"), [
|
||||
(1, 1),
|
||||
(2, 2),
|
||||
])
|
||||
@pytest.mark.parametrize("eager_mode", [False, True])
|
||||
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
|
||||
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
|
||||
@pytest.mark.parametrize("num_logprobs", [None, 5])
|
||||
@pytest.mark.parametrize("is_async", [False, True])
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_step(
|
||||
example_prompts,
|
||||
model: str,
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
eager_mode: int,
|
||||
num_scheduler_steps: int,
|
||||
num_prompts: int,
|
||||
is_async: bool,
|
||||
num_logprobs: Optional[int],
|
||||
) -> None:
|
||||
"""Test vLLM engine with multi-step scheduling in an OpenAI-protocol
|
||||
client/server environment.
|
||||
|
||||
Set up an engine with single-step scheduling as a ground-truth reference.
|
||||
|
||||
Send a completions API request to both engines with the same prompts.
|
||||
|
||||
Validate:
|
||||
* Generated tokens match
|
||||
* Generated logprobs are all very close
|
||||
|
||||
Args:
|
||||
example_prompts: test fixture providing example prompts
|
||||
model: model under test (same for single- and multi-step engines)
|
||||
tp_size: degree of tensor-parallelism
|
||||
pp_size: degree of pipeline-parallelism
|
||||
eager_mode
|
||||
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
|
||||
GPU -> CPU output transfer
|
||||
num_prompts: number of example prompts under test
|
||||
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
|
||||
completions endpoint; `None` -> no logprobs
|
||||
"""
|
||||
|
||||
prompts = example_prompts
|
||||
if len(prompts) < num_prompts:
|
||||
prompts = prompts * ((num_prompts // len(prompts)) + 1)
|
||||
prompts = prompts[:num_prompts]
|
||||
assert len(prompts) == num_prompts
|
||||
|
||||
server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"]
|
||||
ms_server_args = DEFAULT_SERVER_ARGS + \
|
||||
["--num-scheduler-steps", f"{num_scheduler_steps}"]
|
||||
|
||||
if not is_async:
|
||||
ms_server_args += ["--disable-async-output-proc"]
|
||||
|
||||
if eager_mode:
|
||||
ms_server_args.append("--enforce-eager")
|
||||
|
||||
distributed_args = [
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
"--pipeline-parallel-size",
|
||||
str(pp_size),
|
||||
]
|
||||
|
||||
# Spin up client/server & issue completion API requests.
|
||||
# Default `max_wait_seconds` is 240 but was empirically
|
||||
# was raised 3x to 720 *just for this test* due to
|
||||
# observed timeouts in GHA CI
|
||||
ref_completions = await completions_with_server_args(
|
||||
prompts,
|
||||
model,
|
||||
server_args + distributed_args,
|
||||
num_logprobs,
|
||||
max_wait_seconds=5 * 240)
|
||||
test_completions = await completions_with_server_args(
|
||||
prompts,
|
||||
model,
|
||||
ms_server_args + distributed_args,
|
||||
num_logprobs,
|
||||
max_wait_seconds=5 * 240)
|
||||
|
||||
# Assert multi-step scheduling produces identical tokens
|
||||
# to single-step scheduling.
|
||||
ref_generations = get_client_text_generations(ref_completions)
|
||||
test_generations = get_client_text_generations(test_completions)
|
||||
assert ref_generations == test_generations
|
||||
|
||||
# Assert multi-step scheduling produces nearly-identical logprobs
|
||||
# to single-step scheduling.
|
||||
ref_text_logprobs = get_client_text_logprob_generations(ref_completions)
|
||||
test_text_logprobs = get_client_text_logprob_generations(test_completions)
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=ref_text_logprobs,
|
||||
outputs_1_lst=test_text_logprobs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
102
tests/multi_step/test_correctness_llm.py
Normal file
102
tests/multi_step/test_correctness_llm.py
Normal file
@ -0,0 +1,102 @@
|
||||
# Test the LLMEngine with multi-step-decoding
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from ..models.utils import check_logprobs_close, check_outputs_equal
|
||||
|
||||
MODELS = [
|
||||
"JackFram/llama-160m",
|
||||
]
|
||||
NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps
|
||||
NUM_PROMPTS = [10]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("tp_size", [1])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
@pytest.mark.parametrize("enforce_eager", [True])
|
||||
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
|
||||
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
|
||||
@pytest.mark.parametrize("num_logprobs", [None, 5])
|
||||
def test_multi_step_llm(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
tp_size: int,
|
||||
max_tokens: int,
|
||||
enforce_eager: int,
|
||||
num_scheduler_steps: int,
|
||||
num_prompts: int,
|
||||
num_logprobs: Optional[int],
|
||||
) -> None:
|
||||
"""Test vLLM engine with multi-step scheduling via sync LLM Engine.
|
||||
|
||||
Set up a HuggingFace (HF) transformers model as a ground-truth reference.
|
||||
|
||||
Prompt them with the same example prompts.
|
||||
|
||||
Validate:
|
||||
* Generated tokens match
|
||||
* Generated logprobs are all very close
|
||||
|
||||
Args:
|
||||
hf_runner: HF transformers model runner fixture
|
||||
vllm_runner: vLLM model runner fixture
|
||||
example_prompts: test fixture providing example prompts
|
||||
model: model under test (same for single- and multi-step engines)
|
||||
dtype: tensor datatype for engine to utilize
|
||||
tp_size: degree of tensor-parallelism
|
||||
max_tokens: the maximum number of tokens to generate
|
||||
enforce_eager
|
||||
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
|
||||
GPU -> CPU output transfer
|
||||
num_prompts: number of example prompts under test
|
||||
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
|
||||
completions endpoint; `None` -> no logprobs
|
||||
"""
|
||||
|
||||
prompts = example_prompts
|
||||
if len(prompts) < num_prompts:
|
||||
prompts = prompts * ((num_prompts // len(prompts)) + 1)
|
||||
prompts = prompts[:num_prompts]
|
||||
assert len(prompts) == num_prompts
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
enforce_eager=enforce_eager,
|
||||
gpu_memory_utilization=0.7,
|
||||
tensor_parallel_size=tp_size,
|
||||
use_v2_block_manager=True,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
) as vllm_model:
|
||||
vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens)
|
||||
if num_logprobs is None else
|
||||
vllm_model.generate_greedy_logprobs(
|
||||
prompts, max_tokens, num_logprobs))
|
||||
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = (hf_model.generate_greedy(prompts, max_tokens)
|
||||
if num_logprobs is None else
|
||||
hf_model.generate_greedy_logprobs_limit(
|
||||
prompts, max_tokens, num_logprobs))
|
||||
|
||||
if num_logprobs is None:
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
else:
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user