Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits: v0.4.0.pos...jax-tpu (168 commits)
@@ -12,7 +12,13 @@ steps:
command: pytest -v -s async_engine

- label: Basic Correctness Test
command: pytest -v -s basic_correctness
commands:
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py

- label: Core Test
command: pytest -v -s core

@@ -29,12 +35,17 @@ steps:
- pytest -v -s test_pynccl.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py

- label: Engine Test
command: pytest -v -s engine tokenization test_sequence.py test_config.py

- label: Entrypoints Test
command: pytest -v -s entrypoints
commands:
# these tests have to be separated, because each one will allocate all posible GPU memory
- pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
- pytest -v -s entrypoints/test_server_oot_registration.py

- label: Examples Test
working_dir: "/vllm-workspace/examples"

@@ -80,6 +91,9 @@ steps:
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4

- label: Tensorizer Test
command: apt-get install curl libsodium23 && pytest -v -s tensorizer

- label: Metrics Test
command: pytest -v -s metrics

@@ -90,7 +104,7 @@ steps:
- bash run-benchmarks.sh

- label: Documentation Build
working_dir: "/vllm-workspace/docs"
working_dir: "/vllm-workspace/test_docs/docs"
no_gpu: True
commands:
- pip install -r requirements-docs.txt
.github/workflows/mypy.yaml (vendored, new file, 50 lines)
@@ -0,0 +1,50 @@
name: mypy

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  ruff:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.8", "3.9", "3.10", "3.11"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install mypy==1.9.0
        pip install types-setuptools
        pip install types-PyYAML
        pip install types-requests
        pip install types-setuptools
    - name: Mypy
      run: |
        mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
        mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml

        # TODO(sang): Follow up
        # mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
        # mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
        # mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
        # mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
        # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
.github/workflows/publish.yml (vendored, 2 changed lines)
@@ -49,7 +49,7 @@ jobs:
matrix:
os: ['ubuntu-20.04']
python-version: ['3.8', '3.9', '3.10', '3.11']
pytorch-version: ['2.1.2'] # Must be the most recent version that meets requirements.txt.
pytorch-version: ['2.2.1'] # Must be the most recent version that meets requirements-cuda.txt.
cuda-version: ['11.8', '12.1']

steps:
.github/workflows/ruff.yml (vendored, 2 changed lines)
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
.github/workflows/scripts/build.sh (vendored, 2 changed lines)
@@ -9,7 +9,7 @@ LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH

# Install requirements
$python_executable -m pip install wheel packaging
$python_executable -m pip install -r requirements.txt
$python_executable -m pip install -r requirements-cuda.txt

# Limit the number of parallel jobs to avoid OOM
export MAX_JOBS=1
.github/workflows/yapf.yml (vendored, 2 changed lines)
@@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
.gitignore (vendored, 1 changed line)
@@ -181,6 +181,7 @@ _build/
# hip files generated by PyTorch
*.hip
*_hip*
hip_compat.h

# Benchmark dataset
*.json
@@ -19,7 +19,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11")
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")

# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100")

#
# Supported/expected torch versions for CUDA/ROCm.
@@ -31,7 +31,7 @@ set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")
# requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from Dockerfile.rocm
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.1.2")
set(TORCH_SUPPORTED_VERSION_CUDA "2.2.1")
set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
@@ -21,7 +21,6 @@ Express your support on Twitter if vLLM aids you, or simply offer your appreciat
### Build from source

```bash
pip install -r requirements.txt
pip install -e .  # This may take several minutes.
```

@@ -30,6 +29,8 @@ pip install -e . # This may take several minutes.
```bash
pip install -r requirements-dev.txt

# linting and formatting
bash format.sh
# Static type checking
mypy
# Unit tests
Dockerfile (107 changed lines)
@@ -2,6 +2,7 @@
|
||||
# to run the OpenAI compatible server.
|
||||
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
# prepare basic build environment
|
||||
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
|
||||
|
||||
RUN apt-get update -y \
|
||||
@ -16,18 +17,26 @@ RUN ldconfig /usr/local/cuda-12.1/compat/
|
||||
WORKDIR /workspace
|
||||
|
||||
# install build and runtime dependencies
|
||||
COPY requirements.txt requirements.txt
|
||||
COPY requirements-common.txt requirements-common.txt
|
||||
COPY requirements-cuda.txt requirements-cuda.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install -r requirements.txt
|
||||
pip install -r requirements-cuda.txt
|
||||
|
||||
# install development dependencies
|
||||
COPY requirements-dev.txt requirements-dev.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install -r requirements-dev.txt
|
||||
|
||||
# cuda arch list used by torch
|
||||
# can be useful for both `dev` and `test`
|
||||
# explicitly set the list to avoid issues with torch 2.2
|
||||
# see https://github.com/pytorch/pytorch/pull/123243
|
||||
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
|
||||
|
||||
#################### EXTENSION BUILD IMAGE ####################
|
||||
#################### WHEEL BUILD IMAGE ####################
|
||||
FROM dev AS build
|
||||
|
||||
# install build dependencies
|
||||
@ -38,18 +47,16 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
# install compiler cache to speed up compilation leveraging local or remote caching
|
||||
RUN apt-get update -y && apt-get install -y ccache
|
||||
|
||||
# copy input files
|
||||
# files and directories related to build wheels
|
||||
COPY csrc csrc
|
||||
COPY setup.py setup.py
|
||||
COPY cmake cmake
|
||||
COPY CMakeLists.txt CMakeLists.txt
|
||||
COPY requirements.txt requirements.txt
|
||||
COPY requirements-common.txt requirements-common.txt
|
||||
COPY requirements-cuda.txt requirements-cuda.txt
|
||||
COPY pyproject.toml pyproject.toml
|
||||
COPY vllm/__init__.py vllm/__init__.py
|
||||
COPY vllm vllm
|
||||
|
||||
# cuda arch list used by torch
|
||||
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
# max jobs used by Ninja to build extensions
|
||||
ARG max_jobs=2
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
@ -61,7 +68,15 @@ ENV VLLM_INSTALL_PUNICA_KERNELS=1
|
||||
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
python3 setup.py build_ext --inplace
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
python3 setup.py bdist_wheel --dist-dir=dist
|
||||
|
||||
# the `vllm_nccl` package must be installed from source distribution
|
||||
# pip is too smart to store a wheel in the cache, and other CI jobs
|
||||
# will directly use the wheel from the cache, which is not what we want.
|
||||
# we need to remove it manually
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip cache remove vllm_nccl*
|
||||
#################### EXTENSION Build IMAGE ####################
|
||||
|
||||
#################### FLASH_ATTENTION Build IMAGE ####################
|
||||
@ -81,57 +96,59 @@ RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
|
||||
|
||||
#################### FLASH_ATTENTION Build IMAGE ####################
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
# image with vLLM installed
|
||||
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base
|
||||
WORKDIR /vllm-workspace
|
||||
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y python3-pip git vim
|
||||
|
||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||
# this won't be needed for future versions of this docker image
|
||||
# or future versions of triton.
|
||||
RUN ldconfig /usr/local/cuda-12.1/compat/
|
||||
|
||||
# install vllm wheel first, so that torch etc will be installed
|
||||
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
pip install dist/*.whl --verbose
|
||||
|
||||
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
|
||||
#################### TEST IMAGE ####################
|
||||
# image to run unit testing suite
|
||||
FROM dev AS test
|
||||
# note that this uses vllm installed by `pip`
|
||||
FROM vllm-base AS test
|
||||
|
||||
# copy pytorch extensions separately to avoid having to rebuild
|
||||
# when python code changes
|
||||
WORKDIR /vllm-workspace
|
||||
# ADD is used to preserve directory structure
|
||||
ADD . /vllm-workspace/
|
||||
COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/
|
||||
# Install flash attention (from pre-built wheel)
|
||||
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
|
||||
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
|
||||
# ignore build dependencies installation because we are using pre-complied extensions
|
||||
RUN rm pyproject.toml
|
||||
RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . --verbose
|
||||
#################### TEST IMAGE ####################
|
||||
|
||||
|
||||
#################### RUNTIME BASE IMAGE ####################
|
||||
# We used base cuda image because pytorch installs its own cuda libraries.
|
||||
# However pynccl depends on cuda libraries so we had to switch to the runtime image
|
||||
# In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
|
||||
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base
|
||||
|
||||
# libnccl required for ray
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y python3-pip
|
||||
|
||||
WORKDIR /workspace
|
||||
COPY requirements.txt requirements.txt
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install -r requirements.txt
|
||||
pip install -r requirements-dev.txt
|
||||
|
||||
# Install flash attention (from pre-built wheel)
|
||||
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
|
||||
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
|
||||
|
||||
#################### RUNTIME BASE IMAGE ####################
|
||||
# doc requires source code
|
||||
# we hide them inside `test_docs/` , so that this source code
|
||||
# will not be imported by other tests
|
||||
RUN mkdir test_docs
|
||||
RUN mv docs test_docs/
|
||||
RUN mv vllm test_docs/
|
||||
|
||||
#################### TEST IMAGE ####################
|
||||
|
||||
#################### OPENAI API SERVER ####################
|
||||
# openai api server alternative
|
||||
FROM vllm-base AS vllm-openai
|
||||
|
||||
# install additional dependencies for openai api server
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install accelerate hf_transfer modelscope
|
||||
|
||||
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
|
||||
COPY vllm vllm
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||
|
||||
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
|
||||
|
@ -23,6 +23,9 @@ RUN echo "FA_BRANCH is $FA_BRANCH"
|
||||
# In that case, we need to use the python reference attention implementation in vllm
|
||||
ARG BUILD_FA="1"
|
||||
|
||||
# whether to build triton on rocm
|
||||
ARG BUILD_TRITON="1"
|
||||
|
||||
# Install some basic utilities
|
||||
RUN apt-get update && apt-get install python3 python3-pip -y
|
||||
|
||||
@ -75,9 +78,20 @@ RUN if [ "$BUILD_FA" = "1" ]; then \
|
||||
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
|
||||
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
|
||||
|
||||
# build triton
|
||||
RUN if [ "$BUILD_TRITON" = "1" ]; then \
|
||||
mkdir -p libs \
|
||||
&& cd libs \
|
||||
&& pip uninstall -y triton \
|
||||
&& git clone https://github.com/ROCm/triton.git \
|
||||
&& cd triton/python \
|
||||
&& pip3 install . \
|
||||
&& cd ../..; \
|
||||
fi
|
||||
|
||||
COPY ./ /app/vllm
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN python3 -m pip install --upgrade pip numba
|
||||
RUN python3 -m pip install xformers==0.0.23 --no-deps
|
||||
|
||||
RUN cd /app \
|
||||
|
@@ -1,5 +1,6 @@
include LICENSE
include requirements.txt
include requirements-common.txt
include requirements-cuda.txt
include CMakeLists.txt

recursive-include cmake *
README.md (17 changed lines)
@@ -14,18 +14,8 @@ Easy, fast, and cheap LLM serving for everyone

</p>

---

**The Third vLLM Bay Area Meetup (April 2nd 6pm-8:30pm PT)**

We are thrilled to announce our third vLLM Meetup!
The vLLM team will share recent updates and roadmap.
We will also have vLLM collaborators from Roblox coming up to the stage to discuss their experience in deploying LLMs with vLLM.
Please register [here](https://robloxandvllmmeetup2024.splashthat.com/) and join us!

---

*Latest News* 🔥
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
- [2024/01] Added ROCm 6.0 support to vLLM.
- [2023/12] Added ROCm 5.7 support to vLLM.

@@ -80,15 +70,16 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.)
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.)
- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
@ -27,8 +27,8 @@ class RequestFuncInput:
|
||||
class RequestFuncOutput:
|
||||
generated_text: str = ""
|
||||
success: bool = False
|
||||
latency: float = 0
|
||||
ttft: float = 0 # Time to first token
|
||||
latency: float = 0.0
|
||||
ttft: float = 0.0 # Time to first token
|
||||
itl: List[float] = field(
|
||||
default_factory=list) # List of inter-token latencies
|
||||
prompt_len: int = 0
|
||||
@ -58,23 +58,24 @@ async def async_request_tgi(
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
ttft = 0
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
async for chunk in response.content:
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk.decode("utf-8"), "data:")
|
||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
|
||||
"data:")
|
||||
|
||||
data = json.loads(chunk)
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0:
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
@ -119,23 +120,24 @@ async def async_request_trt_llm(
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
ttft = 0
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
async for chunk in response.content:
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk.decode("utf-8"), "data:")
|
||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
|
||||
"data:")
|
||||
|
||||
data = json.loads(chunk)
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0:
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
@ -151,7 +153,7 @@ async def async_request_trt_llm(
|
||||
output.success = True
|
||||
|
||||
else:
|
||||
output.error = response.reason
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
@ -195,7 +197,7 @@ async def async_request_deepspeed_mii(
|
||||
output.generated_text = parsed_resp["text"][0]
|
||||
output.success = True
|
||||
else:
|
||||
output.error = response.reason
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
@ -234,19 +236,20 @@ async def async_request_openai_completions(
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk in response.content:
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
|
||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
|
||||
"data: ")
|
||||
if chunk == "[DONE]":
|
||||
latency = time.perf_counter() - st
|
||||
else:
|
||||
@ -255,7 +258,7 @@ async def async_request_openai_completions(
|
||||
if data["choices"][0]["text"]:
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0:
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
@ -315,19 +318,20 @@ async def async_request_openai_chat_completions(
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk in response.content:
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
|
||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
|
||||
"data: ")
|
||||
if chunk == "[DONE]":
|
||||
latency = time.perf_counter() - st
|
||||
else:
|
||||
@ -337,7 +341,7 @@ async def async_request_openai_chat_completions(
|
||||
delta = data["choices"][0]["delta"]
|
||||
if delta.get("content", None):
|
||||
# First token
|
||||
if ttft == 0:
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
@ -354,7 +358,7 @@ async def async_request_openai_chat_completions(
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
else:
|
||||
output.error = response.reason
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
|
148
benchmarks/bench_cache_write.py
Normal file
148
benchmarks/bench_cache_write.py
Normal file
@ -0,0 +1,148 @@
|
||||
import functools
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
import chex
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
_PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
@jax.jit
|
||||
def write_to_kv_cache1(
|
||||
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
|
||||
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
|
||||
slot_mapping: jax.Array, # [batch_size, seq_len]
|
||||
) -> Tuple[jax.Array, jax.Array]:
|
||||
num_heads = key.shape[-2]
|
||||
head_size = key.shape[-1]
|
||||
|
||||
key = key.reshape(-1, num_heads, head_size)
|
||||
key = key.transpose((1, 0, 2))
|
||||
value = value.reshape(-1, num_heads, head_size)
|
||||
value = value.transpose((1, 0, 2))
|
||||
|
||||
k_cache = k_cache.at[:, slot_mapping.reshape(-1), :].set(key)
|
||||
v_cache = v_cache.at[:, slot_mapping.reshape(-1), :].set(value)
|
||||
return k_cache, v_cache
|
||||
|
||||
|
||||
@functools.partial(jax.jit, donate_argnums=(2, 3))
|
||||
def write_to_kv_cache2(
|
||||
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
|
||||
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
|
||||
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
|
||||
slot_mapping: jax.Array, # [batch_size, seq_len]
|
||||
) -> Tuple[jax.Array, jax.Array]:
|
||||
batch_size = slot_mapping.shape[0]
|
||||
|
||||
def cond(val: _IteratorState):
|
||||
return val.idx < batch_size
|
||||
|
||||
def body(val: _IteratorState):
|
||||
k_cache, v_cache = _write_seq_to_kv_cache(
|
||||
key[val.idx],
|
||||
value[val.idx],
|
||||
val.k_cache,
|
||||
val.v_cache,
|
||||
slot_mapping[val.idx],
|
||||
)
|
||||
val.k_cache = k_cache
|
||||
val.v_cache = v_cache
|
||||
val.idx += 1
|
||||
return val
|
||||
|
||||
iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
|
||||
iterator = jax.lax.while_loop(cond, body, iterator)
|
||||
return iterator.k_cache, iterator.v_cache
|
||||
|
||||
|
||||
@functools.partial(jax.jit, donate_argnums=(2, 3))
|
||||
def _write_seq_to_kv_cache(
|
||||
key: jax.Array, # [seq_len, num_heads, head_size]
|
||||
value: jax.Array, # [seq_len, num_heads, head_size]
|
||||
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
|
||||
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
|
||||
slot_mapping: jax.Array, # [seq_len]
|
||||
) -> Tuple[jax.Array, jax.Array]:
|
||||
seq_len = slot_mapping.shape[0]
|
||||
num_heads, _, head_size = k_cache.shape
|
||||
# Reshape to match the rank of kv_cache.
|
||||
key = key.reshape(seq_len, num_heads, 1, head_size)
|
||||
value = value.reshape(seq_len, num_heads, 1, head_size)
|
||||
|
||||
def cond(val: _IteratorState):
|
||||
return jnp.logical_and(
|
||||
val.idx < seq_len, slot_mapping[val.idx] != _PAD_SLOT_ID)
|
||||
|
||||
def body(val: _IteratorState):
|
||||
slot_idx = slot_mapping[val.idx]
|
||||
val.k_cache = jax.lax.dynamic_update_slice(
|
||||
val.k_cache,
|
||||
key[val.idx],
|
||||
(0, slot_idx, 0),
|
||||
)
|
||||
val.v_cache = jax.lax.dynamic_update_slice(
|
||||
val.v_cache,
|
||||
value[val.idx],
|
||||
(0, slot_idx, 0),
|
||||
)
|
||||
val.idx += 1
|
||||
return val
|
||||
|
||||
iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
|
||||
iterator = jax.lax.while_loop(cond, body, iterator)
|
||||
return iterator.k_cache, iterator.v_cache
|
||||
|
||||
|
||||
@chex.dataclass
|
||||
class _IteratorState:
|
||||
|
||||
idx: jnp.int32
|
||||
k_cache: jnp.ndarray # [num_heads, num_blocks, block_size, head_size]
|
||||
v_cache: jnp.ndarray # [num_heads, num_blocks, block_size, head_size]
|
||||
|
||||
|
||||
def benchmark_write_to_kv_cache(
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
version: int = 1,
|
||||
):
|
||||
if version == 1:
|
||||
f = write_to_kv_cache1
|
||||
elif version == 2:
|
||||
f = write_to_kv_cache2
|
||||
else:
|
||||
raise ValueError(f"Invalid version: {version}")
|
||||
|
||||
rng_key = jax.random.PRNGKey(0)
|
||||
key = jax.random.normal(rng_key, (batch_size, seq_len, num_kv_heads, head_size), dtype=jnp.bfloat16)
|
||||
value = jax.random.normal(rng_key, (batch_size, seq_len, num_kv_heads, head_size), dtype=jnp.bfloat16)
|
||||
k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
|
||||
v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
|
||||
slot_mapping = jax.random.randint(rng_key, (batch_size, seq_len), 0, num_blocks * block_size, dtype=jnp.int32)
|
||||
|
||||
# For JIT compilation.
|
||||
k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
|
||||
k_cache.block_until_ready()
|
||||
|
||||
start = time.time()
|
||||
for _ in range(100):
|
||||
k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
|
||||
k_cache.block_until_ready()
|
||||
end = time.time()
|
||||
print(f"Time taken: {(end - start) * 10:.2f} ms")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for num_blocks in [16, 256, 512, 1024, 2048, 8192, 16384]:
|
||||
print(f"Benchmarking Write to KV Cache w/ {num_blocks} blocks")
|
||||
benchmark_write_to_kv_cache(16, 256, 16, 256, num_blocks, 16, version=1)
|
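The new benchmark's `__main__` block only sweeps `num_blocks` with `version=1`. As a minimal sketch of how the two write strategies could be compared side by side (not part of the diff; it assumes the file is importable as `bench_cache_write` and that JAX is available on the target device):

```python
# Hypothetical driver, not in the diff: run both KV-cache write variants
# (scatter-based version 1 and while-loop-based version 2) on the same shapes.
from bench_cache_write import benchmark_write_to_kv_cache

for num_blocks in (1024, 8192):
    for version in (1, 2):
        print(f"num_blocks={num_blocks}, version={version}")
        benchmark_write_to_kv_cache(
            batch_size=16,
            seq_len=256,
            num_kv_heads=16,
            head_size=256,
            num_blocks=num_blocks,
            block_size=16,
            version=version,
        )
```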
benchmarks/bench_paged_attn.py (new file, 101 lines)
@@ -0,0 +1,101 @@
import argparse
import functools
import time

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention

BLOCK_SIZE = 16
MAX_NUM_BLOCKS_PER_SEQ = 512


@functools.partial(jax.jit, static_argnums=(6, 7))
def paged_attn(
    q: jax.Array,  # [batch, 1, num_heads, head_size]
    k_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
    v_cache: jax.Array,  # [num_kv_heads, num_blocks * block_size, head_size]
    sm_scale: float,
    block_tables: jax.Array,  # [batch, max_num_blocks_per_batch]
    context_lens: jax.Array,  # [batch]
    block_size: int,
    pages_per_compute_block: int,
) -> jax.Array:  # [batch, 1, num_heads, head_size]
    q = q.squeeze(1)
    q = q * sm_scale

    head_size = q.shape[-1]
    num_slots = k_cache.shape[-2]
    k_cache = k_cache.reshape(-1, num_slots // block_size, block_size, head_size)
    v_cache = v_cache.reshape(-1, num_slots // block_size, block_size, head_size)

    output = paged_attention(
        q,
        k_cache,
        v_cache,
        context_lens,
        block_tables,
        pages_per_compute_block=pages_per_compute_block,
    )
    return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])


def benchmark_paged_attn(
    batch_size: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    context_len: int,
    num_blocks: int,
    block_size: int,
    pages_per_compute_block: int,
):
    rng_key = jax.random.PRNGKey(0)
    query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16)
    k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
    v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
    sm_scale = head_size ** -0.5
    block_tables = jax.random.randint(rng_key, (batch_size, MAX_NUM_BLOCKS_PER_SEQ), 0, num_blocks, dtype=jnp.int32)
    context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32)

    # For JIT compilation.
    output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block)
    output.block_until_ready()

    start = time.time()
    for _ in range(100):
        output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block)
    output.block_until_ready()
    end = time.time()

    print(f"Time taken: {(end - start) * 10000:.2f} us")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-heads", type=int, default=16)
    parser.add_argument("--num-kv-heads", type=int, default=16)
    parser.add_argument("--head-size", type=int, default=256)
    parser.add_argument("--context-len", type=int, default=512)
    parser.add_argument("--num-blocks", type=int, default=2048)
    args = parser.parse_args()
    print(args)

    for block_size in [16, 32, 64, 128]:
        for pages_per_compute_block in [1, 2, 4, 8, 16, 32, 64, 128]:
            if pages_per_compute_block > MAX_NUM_BLOCKS_PER_SEQ:
                continue
            if block_size * pages_per_compute_block > 1024:
                continue
            print(f"block_size {block_size}, pages_per_compute_block: {pages_per_compute_block}")
            benchmark_paged_attn(
                args.batch_size,
                args.num_heads,
                args.num_kv_heads,
                args.head_size,
                args.context_len,
                args.num_blocks,
                block_size,
                pages_per_compute_block,
            )
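The script's `__main__` block sweeps `block_size` and `pages_per_compute_block` over fixed grids. A single configuration can also be timed directly by calling `benchmark_paged_attn`; the sketch below is not part of the diff, assumes the file is importable as `bench_paged_attn`, and uses the argparse defaults plus one illustrative block configuration:

```python
# Hypothetical single-configuration run, not in the diff: time one
# (block_size, pages_per_compute_block) pair with the defaults used above.
from bench_paged_attn import benchmark_paged_attn

benchmark_paged_attn(
    batch_size=8,
    num_heads=16,
    num_kv_heads=16,
    head_size=256,
    context_len=512,
    num_blocks=2048,
    block_size=16,
    pages_per_compute_block=8,
)
```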
@ -24,6 +24,7 @@ def main(args: argparse.Namespace):
|
||||
dtype=args.dtype,
|
||||
enforce_eager=args.enforce_eager,
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
quantization_param_path=args.quantization_param_path,
|
||||
device=args.device,
|
||||
ray_workers_use_nsight=args.ray_workers_use_nsight,
|
||||
enable_chunked_prefill=args.enable_chunked_prefill,
|
||||
@ -67,6 +68,7 @@ def main(args: argparse.Namespace):
|
||||
return latency
|
||||
|
||||
print("Warming up...")
|
||||
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
|
||||
run_to_completion(profile_dir=None)
|
||||
|
||||
if args.profile:
|
||||
@ -83,7 +85,12 @@ def main(args: argparse.Namespace):
|
||||
latencies = []
|
||||
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
|
||||
latencies.append(run_to_completion(profile_dir=None))
|
||||
latencies = np.array(latencies)
|
||||
percentages = [10, 25, 50, 75, 90]
|
||||
percentiles = np.percentile(latencies, percentages)
|
||||
print(f'Avg latency: {np.mean(latencies)} seconds')
|
||||
for percentage, percentile in zip(percentages, percentiles):
|
||||
print(f'{percentage}% percentile latency: {percentile} seconds')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -105,9 +112,13 @@ if __name__ == '__main__':
|
||||
default=1,
|
||||
help='Number of generated sequences per prompt.')
|
||||
parser.add_argument('--use-beam-search', action='store_true')
|
||||
parser.add_argument('--num-iters-warmup',
|
||||
type=int,
|
||||
default=10,
|
||||
help='Number of iterations to run for warmup.')
|
||||
parser.add_argument('--num-iters',
|
||||
type=int,
|
||||
default=3,
|
||||
default=30,
|
||||
help='Number of iterations to run.')
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
@ -127,10 +138,23 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
choices=['auto', 'fp8_e5m2'],
|
||||
choices=['auto', 'fp8'],
|
||||
default='auto',
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type.')
|
||||
'Data type for kv cache storage. If "auto", will use model data type. '
|
||||
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
|
||||
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
|
||||
'common inference criteria.')
|
||||
parser.add_argument(
|
||||
'--quantization-param-path',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Path to the JSON file containing the KV cache scaling factors. '
|
||||
'This should generally be supplied, when KV cache dtype is FP8. '
|
||||
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
|
||||
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
|
||||
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
|
||||
'instead supported for common inference criteria.')
|
||||
parser.add_argument(
|
||||
'--profile',
|
||||
action='store_true',
|
||||
@ -145,16 +169,15 @@ if __name__ == '__main__':
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
choices=["cuda"],
|
||||
help='device type for vLLM execution, supporting CUDA only currently.')
|
||||
choices=["cuda", "cpu"],
|
||||
help='device type for vLLM execution, supporting CUDA and CPU.')
|
||||
parser.add_argument('--block-size',
|
||||
type=int,
|
||||
default=16,
|
||||
help='block size of key/value cache')
|
||||
parser.add_argument(
|
||||
'--enable-chunked-prefill',
|
||||
type=bool,
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='If True, the prefill requests can be chunked based on the '
|
||||
'max_num_batched_tokens')
|
||||
parser.add_argument(
|
||||
|
@ -110,7 +110,9 @@ def sample_sonnet_requests(
|
||||
prefix_len: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> List[Tuple[str, str, int, int]]:
|
||||
assert input_len > prefix_len, "input_len must be greater than prefix_len."
|
||||
assert (
|
||||
input_len > prefix_len
|
||||
), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."
|
||||
|
||||
# Load the dataset.
|
||||
with open(dataset_path) as f:
|
||||
@ -131,8 +133,9 @@ def sample_sonnet_requests(
|
||||
base_message, add_generation_prompt=True, tokenize=False)
|
||||
base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)
|
||||
|
||||
assert (input_len > base_prompt_offset
|
||||
), f"Please set 'args.input-len' higher than {base_prompt_offset}."
|
||||
assert (
|
||||
input_len > base_prompt_offset
|
||||
), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}."
|
||||
num_input_lines = round(
|
||||
(input_len - base_prompt_offset) / average_poem_len)
|
||||
|
||||
@ -140,7 +143,7 @@ def sample_sonnet_requests(
|
||||
# prompt are fixed poem lines.
|
||||
assert (
|
||||
prefix_len > base_prompt_offset
|
||||
), f"Please set 'args.prefix-len' higher than {base_prompt_offset}."
|
||||
), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}."
|
||||
|
||||
num_prefix_lines = round(
|
||||
(prefix_len - base_prompt_offset) / average_poem_len)
|
||||
@ -373,9 +376,9 @@ def main(args: argparse.Namespace):
|
||||
input_requests = sample_sonnet_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
input_len=args.input_len,
|
||||
output_len=args.output_len,
|
||||
prefix_len=args.prefix_len,
|
||||
input_len=args.sonnet_input_len,
|
||||
output_len=args.sonnet_output_len,
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
input_requests = [(prompt, prompt_len, output_len)
|
||||
@ -388,9 +391,9 @@ def main(args: argparse.Namespace):
|
||||
input_requests = sample_sonnet_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
input_len=args.input_len,
|
||||
output_len=args.output_len,
|
||||
prefix_len=args.prefix_len,
|
||||
input_len=args.sonnet_input_len,
|
||||
output_len=args.sonnet_output_len,
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
input_requests = [(prompt_formatted, prompt_len, output_len)
|
||||
|
@ -29,22 +29,23 @@ def sample_requests(
|
||||
dataset = [(data["conversations"][0]["value"],
|
||||
data["conversations"][1]["value"]) for data in dataset]
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompts = [prompt for prompt, _ in dataset]
|
||||
prompt_token_ids = tokenizer(prompts).input_ids
|
||||
completions = [completion for _, completion in dataset]
|
||||
completion_token_ids = tokenizer(completions).input_ids
|
||||
tokenized_dataset = []
|
||||
for i in range(len(dataset)):
|
||||
output_len = len(completion_token_ids[i])
|
||||
if fixed_output_len is not None:
|
||||
output_len = fixed_output_len
|
||||
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
|
||||
# Shuffle the dataset.
|
||||
random.shuffle(dataset)
|
||||
|
||||
# Filter out too long sequences.
|
||||
# Filter out sequences that are too long or too short
|
||||
filtered_dataset: List[Tuple[str, int, int]] = []
|
||||
for prompt, prompt_token_ids, output_len in tokenized_dataset:
|
||||
for i in range(len(dataset)):
|
||||
if len(filtered_dataset) == num_requests:
|
||||
break
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompt = dataset[i][0]
|
||||
prompt_token_ids = tokenizer(prompt).input_ids
|
||||
completion = dataset[i][1]
|
||||
completion_token_ids = tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = len(completion_token_ids
|
||||
) if fixed_output_len is None else fixed_output_len
|
||||
if prompt_len < 4 or output_len < 4:
|
||||
# Prune too short sequences.
|
||||
continue
|
||||
@ -53,9 +54,7 @@ def sample_requests(
|
||||
continue
|
||||
filtered_dataset.append((prompt, prompt_len, output_len))
|
||||
|
||||
# Sample the requests.
|
||||
sampled_requests = random.sample(filtered_dataset, num_requests)
|
||||
return sampled_requests
|
||||
return filtered_dataset
|
||||
|
||||
|
||||
def run_vllm(
|
||||
@ -72,13 +71,17 @@ def run_vllm(
|
||||
max_model_len: Optional[int],
|
||||
enforce_eager: bool,
|
||||
kv_cache_dtype: str,
|
||||
quantization_param_path: Optional[str],
|
||||
device: str,
|
||||
enable_prefix_caching: bool,
|
||||
enable_chunked_prefill: bool,
|
||||
max_num_batched_tokens: int,
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
download_dir: Optional[str] = None,
|
||||
) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(model=model,
|
||||
llm = LLM(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
quantization=quantization,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
@ -89,9 +92,13 @@ def run_vllm(
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
enforce_eager=enforce_eager,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
quantization_param_path=quantization_param_path,
|
||||
device=device,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
download_dir=download_dir)
|
||||
download_dir=download_dir,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
)
|
||||
|
||||
# Add the requests to the engine.
|
||||
for prompt, _, output_len in requests:
|
||||
@ -212,14 +219,15 @@ def main(args: argparse.Namespace):
|
||||
args.output_len)
|
||||
|
||||
if args.backend == "vllm":
|
||||
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
|
||||
args.quantization, args.tensor_parallel_size,
|
||||
args.seed, args.n, args.use_beam_search,
|
||||
args.trust_remote_code, args.dtype,
|
||||
args.max_model_len, args.enforce_eager,
|
||||
args.kv_cache_dtype, args.device,
|
||||
args.enable_prefix_caching,
|
||||
args.gpu_memory_utilization, args.download_dir)
|
||||
elapsed_time = run_vllm(
|
||||
requests, args.model, args.tokenizer, args.quantization,
|
||||
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
|
||||
args.trust_remote_code, args.dtype, args.max_model_len,
|
||||
args.enforce_eager, args.kv_cache_dtype,
|
||||
args.quantization_param_path, args.device,
|
||||
args.enable_prefix_caching, args.enable_chunked_prefill,
|
||||
args.max_num_batched_tokens, args.gpu_memory_utilization,
|
||||
args.download_dir)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
@ -306,20 +314,41 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_e5m2"],
|
||||
choices=["auto", "fp8"],
|
||||
default="auto",
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type.')
|
||||
'Data type for kv cache storage. If "auto", will use model data type. '
|
||||
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
|
||||
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
|
||||
'common inference criteria.')
|
||||
parser.add_argument(
|
||||
'--quantization-param-path',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Path to the JSON file containing the KV cache scaling factors. '
|
||||
'This should generally be supplied, when KV cache dtype is FP8. '
|
||||
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
|
||||
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
|
||||
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
|
||||
'instead supported for common inference criteria.')
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
choices=["cuda"],
|
||||
help='device type for vLLM execution, supporting CUDA only currently.')
|
||||
choices=["cuda", "cpu", "tpu"],
|
||||
help='device type for vLLM execution, supporting CUDA and CPU.')
|
||||
parser.add_argument(
|
||||
"--enable-prefix-caching",
|
||||
action='store_true',
|
||||
help="enable automatic prefix caching for vLLM backend.")
|
||||
parser.add_argument("--enable-chunked-prefill",
|
||||
action='store_true',
|
||||
help="enable chunked prefill for vLLM backend.")
|
||||
parser.add_argument('--max-num-batched-tokens',
|
||||
type=int,
|
||||
default=None,
|
||||
help='maximum number of batched tokens per '
|
||||
'iteration')
|
||||
parser.add_argument('--download-dir',
|
||||
type=str,
|
||||
default=None,
|
||||
|
@ -5,7 +5,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._C import ops
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
|
||||
|
||||
NUM_BLOCKS = 1024
|
||||
@ -97,6 +97,9 @@ def main(
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Using default kv_scale
|
||||
kv_scale = 1.0
|
||||
|
||||
for _ in range(num_iters):
|
||||
if version == "v1":
|
||||
ops.paged_attention_v1(
|
||||
@ -112,6 +115,7 @@ def main(
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
kv_scale,
|
||||
)
|
||||
elif version == "v2":
|
||||
ops.paged_attention_v2(
|
||||
@ -130,6 +134,7 @@ def main(
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
kv_scale,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid version: {version}")
|
||||
@ -179,11 +184,13 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_e5m2"],
|
||||
choices=["auto", "fp8"],
|
||||
default="auto",
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type.')
|
||||
parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
|
||||
'Data type for kv cache storage. If "auto", will use model data type. '
|
||||
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
|
||||
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
|
||||
'common inference criteria.')
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
|
@@ -100,6 +100,8 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)

if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
list(REMOVE_ITEM GPU_FLAGS
"-D__CUDA_NO_HALF_OPERATORS__"
"-D__CUDA_NO_HALF_CONVERSIONS__"

@@ -117,6 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)

list(APPEND GPU_FLAGS
"-DUSE_ROCM"
"-DENABLE_FP8_E4M3"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc")
@@ -4,4 +4,4 @@
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"
#include "dtype_fp8_e5m2.cuh"
#include "dtype_fp8.cuh"
@ -22,12 +22,26 @@
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.cuh"
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
|
||||
#if defined(ENABLE_FP8_E5M2)
|
||||
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
||||
#elif defined(ENABLE_FP8_E4M3)
|
||||
#include "../quantization/fp8/amd_detail/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||
@ -78,7 +92,7 @@ template<
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS,
|
||||
bool IS_FP8_E5M2_KV_CACHE,
|
||||
bool IS_FP8_KV_CACHE,
|
||||
int PARTITION_SIZE = 0> // Zero means no partitioning.
|
||||
__device__ void paged_attention_kernel(
|
||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
@ -95,7 +109,8 @@ __device__ void paged_attention_kernel(
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride) {
|
||||
const int kv_head_stride,
|
||||
const float kv_scale) {
|
||||
const int seq_idx = blockIdx.y;
|
||||
const int partition_idx = blockIdx.z;
|
||||
const int max_num_partitions = gridDim.z;
|
||||
@ -142,7 +157,7 @@ __device__ void paged_attention_kernel(
|
||||
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
|
||||
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
|
||||
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
|
||||
#endif
|
||||
|
||||
@ -208,11 +223,16 @@ __device__ void paged_attention_kernel(
|
||||
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
||||
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
||||
const int offset2 = (vec_idx * VEC_SIZE) % x;
|
||||
if constexpr (IS_FP8_E5M2_KV_CACHE) {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
if constexpr (IS_FP8_KV_CACHE) {
|
||||
#if defined(ENABLE_FP8_E5M2)
|
||||
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
// Vector conversion from Quant_vec to K_vec.
|
||||
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
|
||||
#elif defined(ENABLE_FP8_E4M3)
|
||||
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
// Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k
|
||||
// cache vec to k vec in higher precision (FP16, BFloat16, etc.)
|
||||
k_vecs[j] = fp8_e4m3::scaled_vec_conversion<K_vec, Quant_vec>(k_vec_quant, kv_scale);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
@ -292,7 +312,7 @@ __device__ void paged_attention_kernel(
|
||||
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
|
||||
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
|
||||
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
|
||||
#endif
|
||||
using Float_L_vec = typename FloatVec<L_vec>::Type;
|
||||
@ -328,11 +348,16 @@ __device__ void paged_attention_kernel(
|
||||
if (row_idx < HEAD_SIZE) {
|
||||
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
||||
V_vec v_vec;
|
||||
if constexpr (IS_FP8_E5M2_KV_CACHE) {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
if constexpr (IS_FP8_KV_CACHE) {
|
||||
#if defined(ENABLE_FP8_E5M2)
|
||||
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||
// Vector conversion from V_quant_vec to V_vec.
|
||||
v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
|
||||
#elif defined(ENABLE_FP8_E4M3)
|
||||
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||
// Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert
|
||||
// FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.)
|
||||
v_vec = fp8_e4m3::scaled_vec_conversion<V_vec, V_quant_vec>(v_quant_vec, kv_scale);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
@ -423,7 +448,7 @@ template<
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS,
|
||||
bool IS_FP8_E5M2_KV_CACHE>
|
||||
bool IS_FP8_KV_CACHE>
|
||||
__global__ void paged_attention_v1_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
@ -437,11 +462,12 @@ __global__ void paged_attention_v1_kernel(
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
|
||||
const int kv_head_stride,
|
||||
const float kv_scale) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>(
|
||||
/* exp_sums */ nullptr, /* max_logits */ nullptr,
|
||||
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
|
||||
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
|
||||
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
|
||||
}
|
||||
|
||||
// Grid: (num_heads, num_seqs, max_num_partitions).
|
||||
@ -451,7 +477,7 @@ template<
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS,
|
||||
bool IS_FP8_E5M2_KV_CACHE,
|
||||
bool IS_FP8_KV_CACHE,
|
||||
int PARTITION_SIZE>
|
||||
__global__ void paged_attention_v2_kernel(
|
||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
@ -468,11 +494,12 @@ __global__ void paged_attention_v2_kernel(
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
|
||||
const int kv_head_stride,
|
||||
const float kv_scale) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>(
|
||||
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
||||
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
|
||||
q_stride, kv_block_stride, kv_head_stride);
|
||||
q_stride, kv_block_stride, kv_head_stride, kv_scale);
|
||||
}
|
||||
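paged_attention_v2_kernel writes per-partition exp_sums, max_logits and a partial output, which paged_attention_v2_reduce_kernel later merges. The reduce kernel body is not part of this hunk, so the following is only a host-side sketch of the standard numerically stable merge, written under the assumption that each partial output is already normalized by its own partition's exp_sum:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Merge per-partition softmax statistics with the usual rescaling trick.
float reduce_partitions(const std::vector<float>& max_logits,
                        const std::vector<float>& exp_sums,
                        const std::vector<float>& partial_out) {
  const float global_max =
      *std::max_element(max_logits.begin(), max_logits.end());
  float total = 0.f, acc = 0.f;
  for (size_t i = 0; i < max_logits.size(); ++i) {
    // Rescale each partition's exponential mass to the global maximum.
    const float rescaled = exp_sums[i] * std::exp(max_logits[i] - global_max);
    total += rescaled;
    acc += partial_out[i] * rescaled;  // partial_out[i] assumed normalized per partition
  }
  return acc / total;
}

int main() {
  printf("%f\n", reduce_partitions({1.f, 2.f}, {3.f, 4.f}, {0.5f, 0.25f}));
  return 0;
}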
|
||||
// Grid: (num_heads, num_seqs).
|
||||
@ -579,9 +606,9 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
|
||||
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
|
||||
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
IS_FP8_E5M2_KV_CACHE>), shared_mem_size); \
|
||||
IS_FP8_KV_CACHE>), shared_mem_size); \
|
||||
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
|
||||
IS_FP8_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
|
||||
out_ptr, \
|
||||
query_ptr, \
|
||||
key_cache_ptr, \
|
||||
@ -594,14 +621,15 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
alibi_slopes_ptr, \
|
||||
q_stride, \
|
||||
kv_block_stride, \
|
||||
kv_head_stride);
|
||||
kv_head_stride, \
|
||||
kv_scale);
|
||||
|
||||
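VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize is not defined in this hunk; on CUDA it is assumed to boil down to cudaFuncSetAttribute with cudaFuncAttributeMaxDynamicSharedMemorySize, which is required whenever the dynamic shared-memory request exceeds the 48 KB default. A minimal sketch of that call (dummy_kernel and the 64 KB figure are illustrative only):

#include <cuda_runtime.h>
#include <cstdio>

__global__ void dummy_kernel() {
  extern __shared__ char smem[];
  smem[threadIdx.x] = 0;
}

int main() {
  const int shared_mem_size = 64 * 1024;  // more than the 48 KB default
  cudaError_t err = cudaFuncSetAttribute(
      (void*)dummy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);
  if (err != cudaSuccess) {
    printf("attribute not supported: %s\n", cudaGetErrorString(err));
    return 1;
  }
  dummy_kernel<<<1, 32, shared_mem_size>>>();
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}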
// TODO(woosuk): Tune NUM_THREADS.
|
||||
template<
|
||||
typename T,
|
||||
typename CACHE_T,
|
||||
int BLOCK_SIZE,
|
||||
bool IS_FP8_E5M2_KV_CACHE,
|
||||
bool IS_FP8_KV_CACHE,
|
||||
int NUM_THREADS = 128>
|
||||
void paged_attention_v1_launcher(
|
||||
torch::Tensor& out,
|
||||
@ -613,7 +641,8 @@ void paged_attention_v1_launcher(
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
float kv_scale) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -677,8 +706,8 @@ void paged_attention_v1_launcher(
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
|
||||
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
|
||||
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
|
||||
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
|
||||
out, \
|
||||
query, \
|
||||
key_cache, \
|
||||
@ -688,20 +717,21 @@ void paged_attention_v1_launcher(
|
||||
block_tables, \
|
||||
context_lens, \
|
||||
max_context_len, \
|
||||
alibi_slopes);
|
||||
alibi_slopes, \
|
||||
kv_scale);
|
||||
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||
// 1, 2, 4, 64, 128, 256.
|
||||
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
||||
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
|
||||
switch (block_size) { \
|
||||
case 8: \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
|
||||
break; \
|
||||
case 16: \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
|
||||
break; \
|
||||
case 32: \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
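CALL_V1_LAUNCHER_BLOCK_SIZE turns the runtime block_size into a compile-time template argument by switching over the handful of block sizes that are actually compiled (8/16/32, per the NOTE above). A macro-free sketch of the same dispatch idiom; launch and dispatch are illustrative names, not vLLM functions:

#include <cstdio>
#include <stdexcept>

// The kernel body is specialized on BLOCK_SIZE at compile time...
template <int BLOCK_SIZE>
void launch(int num_blocks) {
  printf("launching with BLOCK_SIZE=%d over %d blocks\n", BLOCK_SIZE, num_blocks);
}

// ...and a runtime switch picks the matching instantiation, mirroring the
// macro above (only a few sizes are compiled in to keep build time down).
void dispatch(int block_size, int num_blocks) {
  switch (block_size) {
    case 8:  launch<8>(num_blocks);  break;
    case 16: launch<16>(num_blocks); break;
    case 32: launch<32>(num_blocks); break;
    default: throw std::runtime_error("Unsupported block size");
  }
}

int main() { dispatch(16, 4); return 0; }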
@ -720,7 +750,8 @@ void paged_attention_v1(
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype) {
|
||||
const std::string& kv_cache_dtype,
|
||||
float kv_scale) {
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
|
||||
@ -731,7 +762,7 @@ void paged_attention_v1(
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8_e5m2") {
|
||||
} else if (kv_cache_dtype == "fp8") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
@ -748,7 +779,7 @@ void paged_attention_v1(
|
||||
|
||||
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
|
||||
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE> \
|
||||
IS_FP8_KV_CACHE, PARTITION_SIZE> \
|
||||
<<<grid, block, shared_mem_size, stream>>>( \
|
||||
exp_sums_ptr, \
|
||||
max_logits_ptr, \
|
||||
@ -764,7 +795,8 @@ void paged_attention_v1(
|
||||
alibi_slopes_ptr, \
|
||||
q_stride, \
|
||||
kv_block_stride, \
|
||||
kv_head_stride); \
|
||||
kv_head_stride, \
|
||||
kv_scale); \
|
||||
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
|
||||
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
|
||||
out_ptr, \
|
||||
@ -778,7 +810,7 @@ template<
|
||||
typename T,
|
||||
typename CACHE_T,
|
||||
int BLOCK_SIZE,
|
||||
bool IS_FP8_E5M2_KV_CACHE,
|
||||
bool IS_FP8_KV_CACHE,
|
||||
int NUM_THREADS = 128,
|
||||
int PARTITION_SIZE = 512>
|
||||
void paged_attention_v2_launcher(
|
||||
@ -794,7 +826,8 @@ void paged_attention_v2_launcher(
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
float kv_scale) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -864,8 +897,8 @@ void paged_attention_v2_launcher(
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
|
||||
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
|
||||
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
|
||||
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
|
||||
out, \
|
||||
exp_sums, \
|
||||
max_logits, \
|
||||
@ -878,20 +911,21 @@ void paged_attention_v2_launcher(
|
||||
block_tables, \
|
||||
context_lens, \
|
||||
max_context_len, \
|
||||
alibi_slopes);
|
||||
alibi_slopes, \
|
||||
kv_scale);
|
||||
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||
// 1, 2, 4, 64, 128, 256.
|
||||
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
||||
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
|
||||
switch (block_size) { \
|
||||
case 8: \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
|
||||
break; \
|
||||
case 16: \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
|
||||
break; \
|
||||
case 32: \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
@ -913,7 +947,8 @@ void paged_attention_v2(
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype) {
|
||||
const std::string& kv_cache_dtype,
|
||||
float kv_scale) {
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
|
||||
@ -924,7 +959,7 @@ void paged_attention_v2(
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8_e5m2") {
|
||||
} else if (kv_cache_dtype == "fp8") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
|
@ -8,7 +8,7 @@
|
||||
#endif
|
||||
|
||||
namespace vllm {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
|
||||
// fp8 vector types for quantization of kv cache
|
||||
|
||||
template<>
|
@ -21,9 +21,10 @@ void reshape_and_cache(
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype);
|
||||
const std::string& kv_cache_dtype,
|
||||
const float kv_scale);
|
||||
|
||||
// Just for unittest
|
||||
void convert_fp8_e5m2(
|
||||
void convert_fp8(
|
||||
torch::Tensor& src_cache,
|
||||
torch::Tensor& dst_cache);
|
||||
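The header above renames the test-only helper to convert_fp8, which converts between a high-precision cache and a uint8 (FP8) cache in either direction depending on the dtypes it is given. A hedged sketch of how a C++ caller might exercise it; the shapes are arbitrary, and the snippet has to be linked against the compiled vLLM cache kernels (the real tests drive it from Python through the cache_ops module):

#include <torch/torch.h>

// Declaration from the header above.
void convert_fp8(torch::Tensor& src_cache, torch::Tensor& dst_cache);

// Illustrative round trip: fp16 cache -> uint8 (fp8) cache -> fp16 cache.
// Both tensors must live on the same GPU; the kernel dispatches on the
// source/destination dtypes, with uint8 acting as the fp8 storage type.
void example_round_trip() {
  auto opts = torch::TensorOptions().device(torch::kCUDA, 0);
  torch::Tensor cache = torch::randn({4, 8, 16, 16}, opts.dtype(torch::kFloat16));
  torch::Tensor fp8_cache = torch::empty_like(cache, opts.dtype(torch::kUInt8));
  torch::Tensor recovered = torch::empty_like(cache);
  convert_fp8(cache, fp8_cache);      // quantize
  convert_fp8(fp8_cache, recovered);  // dequantize
}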
|
@ -4,8 +4,10 @@
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
#if defined(ENABLE_FP8_E5M2)
|
||||
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
||||
#elif defined(ENABLE_FP8_E4M3)
|
||||
#include "quantization/fp8/amd_detail/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
@ -151,7 +153,7 @@ void copy_blocks(
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
|
||||
template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache>
|
||||
__global__ void reshape_and_cache_kernel(
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||
@ -163,7 +165,8 @@ __global__ void reshape_and_cache_kernel(
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int x) {
|
||||
const int x,
|
||||
const float kv_scale) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
if (slot_idx < 0) {
|
||||
@ -195,10 +198,13 @@ __global__ void reshape_and_cache_kernel(
|
||||
+ block_offset;
|
||||
scalar_t tgt_key = key[src_key_idx];
|
||||
scalar_t tgt_value = value[src_value_idx];
|
||||
if constexpr (is_fp8_e5m2_kv_cache) {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
if constexpr (is_fp8_kv_cache) {
|
||||
#if defined(ENABLE_FP8_E5M2)
|
||||
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
|
||||
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
|
||||
#elif defined(ENABLE_FP8_E4M3)
|
||||
key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale);
|
||||
value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
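On the write path, reshape_and_cache now divides by kv_scale before encoding to FP8 E4M3 (the Quantize(HP / scale) => FP8 convention). How kv_scale is chosen is outside this diff; a common choice is amax-based calibration, sketched below under the assumption that 240 — the value the MI300 path clamps to in hip_float8_impl.h — is the usable FP8 maximum:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative amax-based choice of kv_scale: after dividing by this scale,
// every cached element fits inside the representable FP8 range.  240.0f is
// the finite maximum the MI300 path clamps to (see fmed3f(..., 240, -240));
// treat the exact bound as an assumption here.
float choose_kv_scale(const std::vector<float>& kv_elements) {
  float amax = 0.f;
  for (float v : kv_elements) amax = std::max(amax, std::fabs(v));
  const float fp8_max = 240.0f;
  return amax > 0.f ? amax / fp8_max : 1.0f;
}

int main() {
  printf("kv_scale = %f\n", choose_kv_scale({0.1f, -3.5f, 2.0f}));
  return 0;
}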
@ -211,8 +217,8 @@ __global__ void reshape_and_cache_kernel(
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
||||
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
|
||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
|
||||
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
@ -223,7 +229,8 @@ __global__ void reshape_and_cache_kernel(
|
||||
num_heads, \
|
||||
head_size, \
|
||||
block_size, \
|
||||
x);
|
||||
x, \
|
||||
kv_scale);
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
@ -231,7 +238,8 @@ void reshape_and_cache(
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype)
|
||||
const std::string& kv_cache_dtype,
|
||||
const float kv_scale)
|
||||
{
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
@ -254,7 +262,7 @@ void reshape_and_cache(
|
||||
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8_e5m2") {
|
||||
} else if (kv_cache_dtype == "fp8") {
|
||||
if (key.dtype() == at::ScalarType::Float) {
|
||||
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
|
||||
} else if (key.dtype() == at::ScalarType::Half) {
|
||||
@ -270,15 +278,17 @@ void reshape_and_cache(
|
||||
namespace vllm {
|
||||
|
||||
template<typename Tout, typename Tin>
|
||||
__global__ void convert_fp8_e5m2_kernel(
|
||||
__global__ void convert_fp8_kernel(
|
||||
const Tin* __restrict__ src_cache,
|
||||
Tout* __restrict__ dst_cache,
|
||||
const int64_t block_stride) {
|
||||
const int64_t block_idx = blockIdx.x;
|
||||
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
||||
int64_t idx = block_idx * block_stride + i;
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
#if defined(ENABLE_FP8_E5M2)
|
||||
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
|
||||
#elif defined(ENABLE_FP8_E4M3)
|
||||
dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
@ -287,16 +297,25 @@ __global__ void convert_fp8_e5m2_kernel(
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \
|
||||
vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
|
||||
#define CALL_CONVERT_FP8(Tout, Tin) \
|
||||
vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
||||
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
|
||||
block_stride);
|
||||
|
||||
void convert_fp8_e5m2(
|
||||
void convert_fp8(
|
||||
torch::Tensor& src_cache,
|
||||
torch::Tensor& dst_cache)
|
||||
{
|
||||
torch::Device src_device = src_cache.device();
|
||||
torch::Device dst_device = dst_cache.device();
|
||||
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
|
||||
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
|
||||
TORCH_CHECK(
|
||||
src_device.index() == dst_device.index(),
|
||||
"src and dst must be on the same GPU");
|
||||
at::cuda::OptionalCUDAGuard device_guard(src_device);
|
||||
|
||||
int64_t num_blocks = src_cache.size(0);
|
||||
int64_t block_stride = src_cache.stride(0);
|
||||
|
||||
@ -305,16 +324,16 @@ void convert_fp8_e5m2(
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if (src_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8_E5M2(uint8_t, float);
|
||||
CALL_CONVERT_FP8(uint8_t, float);
|
||||
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
|
||||
CALL_CONVERT_FP8(uint8_t, uint16_t);
|
||||
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
|
||||
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8_E5M2(float, uint8_t);
|
||||
CALL_CONVERT_FP8(float, uint8_t);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
|
||||
CALL_CONVERT_FP8(uint16_t, uint8_t);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
|
||||
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t);
|
||||
}
|
||||
}
|
||||
|
@ -419,7 +419,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
|
||||
torch::Tensor &context_lens, int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor> &alibi_slopes,
|
||||
const std::string &kv_cache_dtype) {
|
||||
const std::string &kv_cache_dtype, float kv_scale) {
|
||||
TORCH_CHECK(kv_scale == 1.0f);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
|
||||
[&] {
|
||||
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
|
||||
@ -734,7 +735,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
|
||||
torch::Tensor &context_lens, int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor> &alibi_slopes,
|
||||
const std::string &kv_cache_dtype) {
|
||||
const std::string &kv_cache_dtype, float kv_scale) {
|
||||
TORCH_CHECK(kv_scale == 1.0f);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
|
||||
[&] {
|
||||
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
|
||||
|
@ -111,7 +111,9 @@ void copy_blocks(std::vector<torch::Tensor> &key_caches,
|
||||
void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
|
||||
torch::Tensor &key_cache, torch::Tensor &value_cache,
|
||||
torch::Tensor &slot_mapping,
|
||||
const std::string &kv_cache_dtype) {
|
||||
const std::string &kv_cache_dtype, float kv_scale) {
|
||||
TORCH_CHECK(kv_scale == 1.0f);
|
||||
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
|
@ -59,6 +59,8 @@ __global__ void rms_norm_kernel(
|
||||
template<typename torch_type>
|
||||
struct _typeConvert { static constexpr bool exists = false; };
|
||||
|
||||
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
|
||||
// CUDA < 12.0 runs into issues with packed type conversion
|
||||
template<>
|
||||
struct _typeConvert<c10::Half> {
|
||||
static constexpr bool exists = true;
|
||||
@ -85,8 +87,8 @@ struct _typeConvert<c10::BFloat16> {
|
||||
__device__ static inline hip_type convert(float x) { return __float2bfloat16(x); }
|
||||
__device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); }
|
||||
};
|
||||
#endif
|
||||
|
||||
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
|
||||
|
||||
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
|
||||
for appropriate specializations of fused_add_rms_norm_kernel.
|
||||
|
@ -14,7 +14,8 @@ void paged_attention_v1(
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype);
|
||||
const std::string& kv_cache_dtype,
|
||||
float kv_scale);
|
||||
|
||||
void paged_attention_v2(
|
||||
torch::Tensor& out,
|
||||
@ -31,7 +32,8 @@ void paged_attention_v2(
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype);
|
||||
const std::string& kv_cache_dtype,
|
||||
float kv_scale);
|
||||
|
||||
void rms_norm(
|
||||
torch::Tensor& out,
|
||||
|
@ -14,6 +14,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
f(in_T, out_T, W_T, narrow, 128) \
|
||||
f(in_T, out_T, W_T, narrow, 256) \
|
||||
f(in_T, out_T, W_T, narrow, 512) \
|
||||
f(in_T, out_T, W_T, narrow, 640) \
|
||||
f(in_T, out_T, W_T, narrow, 768) \
|
||||
f(in_T, out_T, W_T, narrow, 1024) \
|
||||
f(in_T, out_T, W_T, narrow, 1152) \
|
||||
@ -46,6 +47,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
f(in_T, out_T, W_T, narrow, 13696) \
|
||||
f(in_T, out_T, W_T, narrow, 13824) \
|
||||
f(in_T, out_T, W_T, narrow, 14336) \
|
||||
f(in_T, out_T, W_T, narrow, 15360) \
|
||||
f(in_T, out_T, W_T, narrow, 16384) \
|
||||
f(in_T, out_T, W_T, narrow, 20480) \
|
||||
f(in_T, out_T, W_T, narrow, 22016) \
|
||||
@ -59,7 +61,17 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
f(in_T, out_T, W_T, narrow, 33024) \
|
||||
f(in_T, out_T, W_T, narrow, 36864) \
|
||||
f(in_T, out_T, W_T, narrow, 49152) \
|
||||
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
|
||||
f(in_T, out_T, W_T, narrow, 64000) \
|
||||
f(in_T, out_T, W_T, narrow, 64256) \
|
||||
f(in_T, out_T, W_T, narrow, 64512) \
|
||||
f(in_T, out_T, W_T, narrow, 102400) \
|
||||
f(in_T, out_T, W_T, narrow, 102656) \
|
||||
f(in_T, out_T, W_T, narrow, 102912) \
|
||||
f(in_T, out_T, W_T, narrow, 128000) \
|
||||
f(in_T, out_T, W_T, narrow, 128256) \
|
||||
f(in_T, out_T, W_T, narrow, 128512) \
|
||||
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
|
||||
// and vllm/tests/lora/test_punica.py
|
||||
|
||||
// Keep this in sync with vllm/config::LoRAConfig
|
||||
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
|
||||
|
@ -20,8 +20,8 @@ inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
|
||||
}
|
||||
}
|
||||
|
||||
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
|
||||
return (uint32_t(a) << 16) | uint32_t(b);
|
||||
inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
|
||||
return (uint64_t(a) << 32) | uint64_t(b);
|
||||
}
|
||||
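pack_u16 could only form a dispatch key for feature sizes below 65536, which the newly added vocabulary-sized dimensions (up to 128512) exceed; pack_u32 widens the key to 64 bits so the switch in launch_bgmv_kernel can still compare against compile-time constants, and the h_in/h_out bound below is relaxed to match. A standalone sketch of the idiom (the shapes in the switch are illustrative, not the actual bgmv pairings):

#include <cstdint>
#include <cstdio>

// Same helper as in the diff: two 32-bit dimensions packed into one 64-bit
// switch key.  constexpr keeps the case labels compile-time constants.
inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
  return (uint64_t(a) << 32) | uint64_t(b);
}

const char* describe(uint32_t in_features, uint32_t out_features) {
  switch (pack_u32(in_features, out_features)) {
    case pack_u32(4096, 128256): return "vocab-sized output (needs the 64-bit key)";
    case pack_u32(4096, 4096):   return "square projection";
    default:                     return "no specialized kernel";
  }
}

int main() {
  // 128256 > 65535, so the old pack_u16 key could not represent this case.
  printf("%s\n", describe(4096, 128256));
  return 0;
}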
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
@ -46,13 +46,13 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
|
||||
template <typename in_T, typename out_T, typename W_T>
|
||||
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
|
||||
const int64_t *lora_indices,
|
||||
uint16_t in_features, uint16_t out_features,
|
||||
uint32_t in_features, uint32_t out_features,
|
||||
int64_t y_offset, int64_t full_y_size,
|
||||
int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale) {
|
||||
switch (pack_u16(in_features, out_features)) {
|
||||
switch (pack_u32(in_features, out_features)) {
|
||||
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
|
||||
case pack_u16(feat_in, feat_out): \
|
||||
case pack_u32(feat_in, feat_out): \
|
||||
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
|
||||
full_y_size, batch_size, num_layers, \
|
||||
layer_idx, scale); \
|
||||
@ -93,7 +93,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
CHECK_EQ(y.size(0), x.size(0));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
bool ok = false;
|
||||
if (h_in < 65536 && h_out < 65536) {
|
||||
if (h_in <= 128512 && h_out <= 128512) {
|
||||
// TODO: See if we can get rid of this massive nested switch
|
||||
switch (x.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
@ -325,7 +325,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
CHECK_EQ(y.size(0), x.size(0));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
bool ok = false;
|
||||
if (h_in < 65536 && h_out < 65536) {
|
||||
if (h_in <= 128512 && h_out <= 128512) {
|
||||
// TODO: See if we can get rid of this massive nested switch
|
||||
switch (x.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
|
@ -91,9 +91,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
cache_ops.def(
"convert_fp8_e5m2",
&convert_fp8_e5m2,
"Convert the key and value cache to fp8_e5m2 data type");
"convert_fp8",
&convert_fp8,
"Convert the key and value cache to fp8 data type");

// Cuda utils
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
|
csrc/quantization/fp8/amd_detail/hip_float8.h (new file, 167 lines)
@ -0,0 +1,167 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#include <hip/hip_runtime.h>
|
||||
#else
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
#include <iostream>
|
||||
#endif
|
||||
|
||||
#include "hip_float8_impl.h"
|
||||
|
||||
struct alignas(1) hip_fp8
|
||||
{
|
||||
struct from_bits_t
|
||||
{
|
||||
};
|
||||
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); }
|
||||
uint8_t data;
|
||||
|
||||
hip_fp8() = default;
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
|
||||
explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
|
||||
: data(v)
|
||||
{
|
||||
}
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
// NOTE: ON-DEVICE... always optimal bias
|
||||
explicit HIP_FP8_DEVICE hip_fp8(float v)
|
||||
: data(hip_fp8_impl::to_fp8_from_fp32(v))
|
||||
{
|
||||
}
|
||||
|
||||
explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
|
||||
: hip_fp8(static_cast<float>(v))
|
||||
{
|
||||
}
|
||||
|
||||
// Host only implementation using s/w simulation
|
||||
explicit HIP_FP8_HOST
|
||||
#else // __HIP__MI300__
|
||||
// both Host and DEVICE for non-MI300 using s/w simulation
|
||||
explicit HIP_FP8_HOST_DEVICE
|
||||
#endif // __HIP__MI300__
|
||||
hip_fp8(float v)
|
||||
{
|
||||
data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v);
|
||||
}
|
||||
|
||||
explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
|
||||
: hip_fp8(static_cast<float>(v))
|
||||
{
|
||||
}
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
// upcast using device specific intrinsic
|
||||
explicit inline HIP_FP8_DEVICE operator float() const
|
||||
{
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(data);
|
||||
|
||||
// upcast
|
||||
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
|
||||
|
||||
return fval;
|
||||
}
|
||||
|
||||
explicit inline HIP_FP8_HOST operator float() const
|
||||
#else // __HIP__MI300__
|
||||
explicit inline HIP_FP8_HOST_DEVICE operator float() const
|
||||
#endif // __HIP__MI300__
|
||||
{
|
||||
return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data);
|
||||
}
|
||||
};
|
||||
|
||||
namespace std
|
||||
{
|
||||
inline hip_fp8 sin(hip_fp8 a)
|
||||
{
|
||||
return hip_fp8(sinf(float(a)));
|
||||
}
|
||||
inline hip_fp8 cos(hip_fp8 a)
|
||||
{
|
||||
return hip_fp8(cosf(float(a)));
|
||||
}
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a)
|
||||
{
|
||||
return a;
|
||||
}
|
||||
} // namespace std
|
||||
|
||||
// Special operator overloading
|
||||
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8)
|
||||
{
|
||||
return os << float(f8);
|
||||
}
|
||||
|
||||
// all + operator overloading with mixed types
|
||||
// mixed types, always converts to f32, does computation in f32, and returns float
|
||||
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b)
|
||||
{
|
||||
return (fa + float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb)
|
||||
{
|
||||
return (float(a) + fb);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return hip_fp8(float(a) + float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b)
|
||||
{
|
||||
return a = hip_fp8(float(a) + float(b));
|
||||
}
|
||||
|
||||
// overloading multiplication, always returns float,
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return float(a) * float(b);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b)
|
||||
{
|
||||
return (a * float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b)
|
||||
{
|
||||
return (float(a) * b);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b)
|
||||
{
|
||||
return ((float)a * float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b)
|
||||
{
|
||||
return ((float)a * float(b));
|
||||
}
|
||||
|
||||
// overloading for compare
|
||||
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return (a.data == b.data);
|
||||
}
|
||||
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return (a.data != b.data);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return static_cast<float>(a) >= static_cast<float>(b);
|
||||
}
|
||||
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return static_cast<float>(a) > static_cast<float>(b);
|
||||
}
|
csrc/quantization/fp8/amd_detail/hip_float8_impl.h (new file, 316 lines)
@ -0,0 +1,316 @@
|
||||
#pragma once
|
||||
|
||||
#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#define __HIP__MI300__
|
||||
#endif
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#define HIP_FP8_HOST_DEVICE __host__ __device__
|
||||
#define HIP_FP8_HOST __host__
|
||||
#define HIP_FP8_DEVICE __device__
|
||||
#else
|
||||
#define HIP_FP8_HOST_DEVICE
|
||||
#define HIP_FP8_HOST
|
||||
#define HIP_FP8_DEVICE
|
||||
#endif
|
||||
|
||||
namespace hip_fp8_impl
|
||||
{
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v)
|
||||
{
|
||||
uint8_t i8data;
|
||||
union {
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // NOTE: not endian independent
|
||||
} val;
|
||||
|
||||
uint32_t ival = 0;
|
||||
val.fval = v;
|
||||
|
||||
if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
|
||||
}
|
||||
|
||||
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
|
||||
false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
i8data = val.i8val[0];
|
||||
|
||||
return i8data;
|
||||
}
|
||||
#endif // __HIP__MI300__
|
||||
|
||||
HIP_FP8_HOST inline int clz(uint32_t x)
|
||||
{
|
||||
return __builtin_clz(x);
|
||||
}
|
||||
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
|
||||
HIP_FP8_DEVICE inline int clz(uint32_t x)
|
||||
{
|
||||
return __clz(x);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
|
||||
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0)
|
||||
{
|
||||
#ifdef __HIPCC__
|
||||
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||
#else
|
||||
constexpr bool is_half = false;
|
||||
#endif
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
static_assert(wm + we == 7, "wm+we==7");
|
||||
static_assert(is_half || is_float, "Only half and float can be cast to f8");
|
||||
|
||||
const int mfmt = (sizeof(T) == 4) ? 23 : 10;
|
||||
uint32_t x;
|
||||
if (sizeof(T) == 4) {
|
||||
x = reinterpret_cast<uint32_t&>(_x);
|
||||
} else {
|
||||
x = reinterpret_cast<uint16_t&>(_x);
|
||||
}
|
||||
|
||||
uint32_t head, mantissa;
|
||||
int exponent, bias;
|
||||
uint32_t sign;
|
||||
|
||||
if (sizeof(T) == 4) {
|
||||
head = x & 0xFF800000;
|
||||
mantissa = x & 0x7FFFFF;
|
||||
exponent = (head >> 23) & 0xFF;
|
||||
sign = head >> 31;
|
||||
bias = 127;
|
||||
} else {
|
||||
head = x & 0xFC00;
|
||||
mantissa = x & 0x3FF;
|
||||
exponent = (head >> 10) & 0x1F;
|
||||
sign = head >> 15;
|
||||
bias = 15;
|
||||
}
|
||||
|
||||
uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
|
||||
|
||||
// Deal with inf and NaNs
|
||||
if (negative_zero_nan) {
|
||||
if (sizeof(T) == 4) {
|
||||
if ((x & 0x7F800000) == 0x7F800000) {
|
||||
return 0x80;
|
||||
}
|
||||
} else {
|
||||
// if(__hisinf(x) || __hisnan(x))
|
||||
if ((x & 0x7C00) == 0x7C00) {
|
||||
return 0x80;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (sizeof(T) == 4) {
|
||||
if ((x & 0x7F800000) == 0x7F800000) {
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
} else {
|
||||
if ((x & 0x7C00) == 0x7C00) {
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (x == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// First need to check if it is normal or denorm as there is a difference of
|
||||
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
|
||||
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
|
||||
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
|
||||
// need to check whether there is carry and adjust exponent and mantissa again
|
||||
|
||||
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
|
||||
// bits
|
||||
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
|
||||
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
|
||||
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||
// f8_exponent is the converted f8 exponent with bias encoding
|
||||
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
||||
// the difference needs to be adjusted and mantissa shifted
|
||||
int act_exponent, f8_exponent, exponent_diff;
|
||||
|
||||
if (exponent == 0) { // fp32/fp16 is in denormal.
|
||||
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
|
||||
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
|
||||
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
|
||||
exponent bias 16. It means that there are some numbers in fp16 denormal but they
|
||||
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
|
||||
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
|
||||
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
|
||||
act_exponent = exponent - bias + 1;
|
||||
exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal
|
||||
} else { // fp32/fp16 is normal with implicit 1
|
||||
act_exponent = exponent - bias;
|
||||
if (act_exponent <= f8_denormal_act_exponent) {
|
||||
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
|
||||
range. For example fp8 nanoo mode, denormal exponent is -7, but if the
|
||||
fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
|
||||
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
|
||||
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
|
||||
exponent_diff = f8_denormal_act_exponent - act_exponent;
|
||||
} else { // both fp32/fp16 and f8 are in normal range
|
||||
exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
|
||||
// for this case,
|
||||
// act_exponent could be larger. Just that it does not need shift mantissa
|
||||
}
|
||||
mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
|
||||
}
|
||||
|
||||
bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
|
||||
static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
|
||||
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
|
||||
done before we shift right as shift right could rip off some residual part
|
||||
and make something not midpoint look like midpoint. For example, the fp16
|
||||
number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
|
||||
shift right by 4 bits, it would look like midpoint.
|
||||
*/
|
||||
|
||||
if (exponent_diff > 0) {
|
||||
mantissa >>= exponent_diff;
|
||||
} else if (exponent_diff == -1) {
|
||||
mantissa <<= -exponent_diff;
|
||||
}
|
||||
bool implicit_one = mantissa & (1 << mfmt);
|
||||
// if there is no implicit 1, it means the f8 is denormal and need to adjust
|
||||
// to denorm exponent
|
||||
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
// Now we have the exponent and mantissa adjusted
|
||||
uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
|
||||
bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit that
|
||||
// is not truncated is 1
|
||||
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
|
||||
|
||||
// Now we deal with overflow
|
||||
if (f8_exponent == 0) {
|
||||
if ((1 << mfmt) & mantissa) {
|
||||
f8_exponent = 1; // denormal overflow to become normal, promote exponent
|
||||
}
|
||||
} else {
|
||||
if ((1 << (mfmt + 1)) & mantissa) {
|
||||
mantissa >>= 1;
|
||||
f8_exponent++;
|
||||
}
|
||||
}
|
||||
|
||||
mantissa >>= (mfmt - wm);
|
||||
|
||||
// above range: quantize to maximum possible float of the same sign
|
||||
const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
|
||||
if (f8_exponent > max_exp) {
|
||||
if (clip) {
|
||||
mantissa = (1 << wm) - 1;
|
||||
f8_exponent = max_exp;
|
||||
} else {
|
||||
return signed_inf;
|
||||
}
|
||||
}
|
||||
|
||||
if (f8_exponent == 0 && mantissa == 0) {
|
||||
return negative_zero_nan ? 0 : (sign << 7);
|
||||
}
|
||||
mantissa &= (1 << wm) - 1;
|
||||
return (sign << 7) | (f8_exponent << wm) | mantissa;
|
||||
}
|
||||
|
||||
template <int we, int wm, typename T = float, bool negative_zero_nan = true>
|
||||
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x)
|
||||
{
|
||||
#ifdef __HIPCC__
|
||||
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||
#else
|
||||
constexpr bool is_half = false;
|
||||
#endif
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
static_assert(is_half || is_float, "only half and float are supported");
|
||||
|
||||
constexpr int weo = is_half ? 5 : 8;
|
||||
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
|
||||
|
||||
T fInf, fNegInf, fNaN, fNeg0;
|
||||
|
||||
#ifdef __HIPCC__
|
||||
if (is_half) {
|
||||
const uint16_t ihInf = 0x7C00;
|
||||
const uint16_t ihNegInf = 0xFC00;
|
||||
const uint16_t ihNaN = 0x7C01;
|
||||
const uint16_t ihNeg0 = 0x8000;
|
||||
fInf = reinterpret_cast<const _Float16&>(ihInf);
|
||||
fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
|
||||
fNaN = reinterpret_cast<const _Float16&>(ihNaN);
|
||||
fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
|
||||
} else
|
||||
#endif
|
||||
if (is_float) {
|
||||
const uint32_t ifInf = 0x7F800000;
|
||||
const uint32_t ifNegInf = 0xFF800000;
|
||||
const uint32_t ifNaN = 0x7F800001;
|
||||
const uint32_t ifNeg0 = 0x80000000;
|
||||
fInf = reinterpret_cast<const float&>(ifInf);
|
||||
fNegInf = reinterpret_cast<const float&>(ifNegInf);
|
||||
fNaN = reinterpret_cast<const float&>(ifNaN);
|
||||
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
|
||||
}
|
||||
|
||||
if (x == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t sign = x >> 7;
|
||||
uint32_t mantissa = x & ((1 << wm) - 1);
|
||||
int exponent = (x & 0x7F) >> wm;
|
||||
if (negative_zero_nan) {
|
||||
if (x == 0x80) {
|
||||
return fNaN;
|
||||
}
|
||||
} else {
|
||||
if (x == 0x80) {
|
||||
return fNeg0;
|
||||
}
|
||||
if (exponent == ((1 << we) - 1)) {
|
||||
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
|
||||
}
|
||||
}
|
||||
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
|
||||
if (we == 5 && is_half && !negative_zero_nan) {
|
||||
retval = x << 8;
|
||||
return reinterpret_cast<const T&>(retval);
|
||||
}
|
||||
|
||||
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
|
||||
// subnormal input
|
||||
if (exponent == 0) {
|
||||
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
|
||||
int sh = 1 + clz(mantissa) - (32 - wm);
|
||||
mantissa <<= sh;
|
||||
exponent += 1 - sh;
|
||||
mantissa &= ((1 << wm) - 1);
|
||||
}
|
||||
exponent += exp_low_cutoff - 1;
|
||||
mantissa <<= wmo - wm;
|
||||
|
||||
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
|
||||
if (exponent <= 0) {
|
||||
mantissa |= 1 << wmo;
|
||||
mantissa >>= 1 - exponent;
|
||||
exponent = 0;
|
||||
}
|
||||
|
||||
if (sizeof(T) == 2) {
|
||||
retval = (sign << 15) | (exponent << 10) | mantissa;
|
||||
} else {
|
||||
retval = (sign << 31) | (exponent << 23) | mantissa;
|
||||
}
|
||||
return reinterpret_cast<const T&>(retval);
|
||||
}
|
||||
|
||||
} // namespace hip_fp8_impl
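The to_float8/from_float8 templates above implement round-to-nearest-even with optional clipping, so out-of-range inputs saturate to the largest finite E4M3 value instead of mapping to the NaN encoding when clip is true. A host-only round-trip sketch (hip_float8.h is included only to pull in the standard headers the emulation needs; exact byte values are not asserted here):

#include <cstdint>
#include <cstdio>
#include "hip_float8.h"   // assumed on the include path; brings in hip_float8_impl.h

int main() {
  using namespace hip_fp8_impl;

  // Round-trip a representable value through the software encoder/decoder.
  const uint8_t small = to_float8<4, 3, float, true, true>(1.5f);
  printf("1.5f   -> 0x%02x -> %f\n", (unsigned)small, from_float8<4, 3, float, true>(small));

  // 1e6 is far outside the e4m3 range; with clip=true it saturates to the
  // largest finite value rather than the NaN encoding.
  const uint8_t big = to_float8<4, 3, float, true, true>(1.0e6f);
  printf("1.0e6f -> 0x%02x -> %f\n", (unsigned)big, from_float8<4, 3, float, true>(big));
  return 0;
}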
|
csrc/quantization/fp8/amd_detail/quant_utils.cuh (new file, 517 lines)
@ -0,0 +1,517 @@
|
||||
#pragma once
|
||||
#include "hip_float8.h"
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_bfloat16.h>
|
||||
|
||||
#include "../../../attention/dtype_float32.cuh"
|
||||
#include "../../../attention/dtype_bfloat16.cuh"
|
||||
|
||||
namespace vllm
|
||||
{
|
||||
namespace fp8_e4m3 {
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout vec_conversion(const Tin& x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
__half_raw res;
|
||||
res.data = static_cast<float>(f8);
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r.x.data = f2[0];
|
||||
tmp.h2r.y.data = f2[1];
|
||||
return tmp.ui32;
|
||||
#else
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
|
||||
tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
|
||||
tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
||||
return tmp.u32;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
||||
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
|
||||
{
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
||||
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
using __nv_bfloat16 = __hip_bfloat16;
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
float f{f8};
|
||||
return __float2bfloat16(f);
|
||||
}
|
||||
|
||||
using __nv_bfloat162 = __hip_bfloat162;
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__nv_bfloat162 res;
|
||||
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
bf16_4_t res;
|
||||
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
|
||||
{
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
||||
return static_cast<float>(fp8);
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
float2 res;
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
res.x = f2[0];
|
||||
res.y = f2[1];
|
||||
return res;
|
||||
#else
|
||||
float2 res;
|
||||
res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
|
||||
res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
||||
return res;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
Float4_ res;
|
||||
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
|
||||
{
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
|
||||
hip_fp8 f8{static_cast<float>(tmp.data)};
|
||||
return f8.data;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
|
||||
{
|
||||
hip_fp8 res{__bfloat162float(a)};
|
||||
return res.data;
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
|
||||
{
|
||||
hip_fp8 f8(a);
|
||||
return f8.data;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
// float2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
|
||||
{
|
||||
union {
|
||||
half2 float16;
|
||||
uint32_t uint32;
|
||||
};
|
||||
|
||||
float16 = __float22half2_rn(a);
|
||||
return uint32;
|
||||
}
|
||||
|
||||
// Float4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
|
||||
{
|
||||
uint2 b;
|
||||
float2 val;
|
||||
val.x = a.x.x;
|
||||
val.y = a.x.y;
|
||||
b.x = vec_conversion<uint32_t, float2>(val);
|
||||
|
||||
val.x = a.y.x;
|
||||
val.y = a.y.y;
|
||||
b.y = vec_conversion<uint32_t, float2>(val);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
|
||||
{
|
||||
float4 b;
|
||||
b.x = a.x.x;
|
||||
b.y = a.x.y;
|
||||
b.z = a.y.x;
|
||||
b.w = a.y.y;
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
|
||||
{
|
||||
uint4 b;
|
||||
b.x = vec_conversion<uint32_t, float2>(a.x);
|
||||
b.y = vec_conversion<uint32_t, float2>(a.y);
|
||||
b.z = vec_conversion<uint32_t, float2>(a.z);
|
||||
b.w = vec_conversion<uint32_t, float2>(a.w);
|
||||
return b;
|
||||
}
|
||||
|
||||
// float2 -> bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a)
|
||||
{
|
||||
__nv_bfloat162 b = __float22bfloat162_rn(a);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float4 -> bfloat162x2
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_& a)
|
||||
{
|
||||
bf16_4_t b;
|
||||
b.x = __float22bfloat162_rn(a.x);
|
||||
b.y = __float22bfloat162_rn(a.y);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float8 -> bfloat162x4
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_& a)
|
||||
{
|
||||
bf16_8_t b;
|
||||
b.x = __float22bfloat162_rn(a.x);
|
||||
b.y = __float22bfloat162_rn(a.y);
|
||||
b.z = __float22bfloat162_rn(a.z);
|
||||
b.w = __float22bfloat162_rn(a.w);
|
||||
return b;
|
||||
}
|
||||
|
||||
|
||||
/* Scaled and vectorized conversions, for data exchange between high and low precision domains
|
||||
|
||||
Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale )
|
||||
s.t.
|
||||
Quantize(HP / scale) => FP8
|
||||
Dequant(FP8) * scale => HP
|
||||
|
||||
*/
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale)
|
||||
{
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
__half_raw res;
|
||||
res.data = static_cast<float>(f8) * scale;
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, const float scale)
|
||||
{
|
||||
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r.x.data = f2[0] * scale;
|
||||
tmp.h2r.y.data = f2[1] * scale;
|
||||
return tmp.ui32;
|
||||
#else
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
|
||||
  tmp.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
  tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
  return tmp.u32;
#endif
}

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale)
{
  union {
    uint2 u32x2;
    uint32_t u32[2];
  } tmp;
  tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
  tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
  return tmp.u32x2;
}

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale)
{
  union {
    uint4 u64x2;
    uint2 u64[2];
  } tmp;
  tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
  tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
  return tmp.u64x2;
}

using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, const float scale)
{
  hip_fp8 f8{a, hip_fp8::from_bits()};
  float f{f8};
  return __float2bfloat16(f * scale);
}

using __nv_bfloat162 = __hip_bfloat162;

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, const float scale)
{
  __nv_bfloat162 res;
  res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
  res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
  return res;
}

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, const float scale)
{
  bf16_4_t res;
  res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
  res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale);
  return res;
}

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale)
{
  bf16_4_t tmp1, tmp2;
  tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
  tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(const uint8_t& a, const float scale)
{
  hip_fp8 fp8{a, hip_fp8::from_bits()};
  return static_cast<float>(fp8) * scale;
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale)
{
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  float2 res;
  const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  res.x = f2[0] * scale;
  res.y = f2[1] * scale;
  return res;
#else
  float2 res;
  res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
  res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
  return res;
#endif
}

// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale)
{
  Float4_ res;
  res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
  res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
  return res;
}

// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale)
{
  Float4_ tmp1, tmp2;
  tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
  tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}


/* Quantize(HP / scale) => FP8 */

// TODO(Hai): vectorized to add
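// Editorial note (added comment): the dequantization helpers above convert the
// FP8 bit pattern to a higher-precision value and multiply it by `scale`; the
// helpers below perform the inverse, dividing the high-precision input by
// `scale` before packing it into an FP8 byte via hip_fp8.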

// half -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale)
{
  __half_raw tmp;
  tmp.x = a;

  hip_fp8 f8{static_cast<float>(tmp.data)/scale};
  return f8.data;
}

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a, const float scale)
{
  hip_fp8 res{__bfloat162float(a)/scale};
  return res.data;
}

// float -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(const float& a, const float scale)
{
  hip_fp8 f8(a/scale);
  return f8.data;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale)
{
  Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}

}
} // namespace vllm
@@ -2067,7 +2067,7 @@ void gptq_shuffle
  const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
  vllm::gptq::shuffle_exllama_weight(
      (uint32_t*) q_weight.data_ptr(),
      q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
      q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(),
      q_weight.size(0) * 32 / bit,
      q_weight.size(1),
      bit
@@ -11,13 +11,11 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.

import logging
import os
import sys
from typing import List

from sphinx.ext import autodoc

sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))

logger = logging.getLogger(__name__)

# -- Project information -----------------------------------------------------
@@ -48,7 +46,7 @@ templates_path = ['_templates']
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
exclude_patterns: List[str] = []

# Exclude the prompt "$" when copying code
copybutton_prompt_text = r"\$ "
@@ -85,6 +83,7 @@ autodoc_mock_imports = [
    "vllm._C",
    "numpy",
    "tqdm",
    "tensorizer",
]

for mock_target in autodoc_mock_imports:
@@ -85,13 +85,3 @@ You can also build and install vLLM from source:

        $ nvcc --version # verify that nvcc is in your PATH
        $ ${CUDA_HOME}/bin/nvcc --version # verify that nvcc is in your CUDA_HOME

.. note::
    If you are developing the C++ backend of vLLM, consider building vLLM with

    .. code-block:: console

        $ python setup.py develop

    since it will give you incremental builds. The downside is that this method
    is `deprecated by setuptools <https://github.com/pypa/setuptools/issues/917>`_.
@@ -91,7 +91,8 @@ Documentation
   :caption: Quantization

   quantization/auto_awq
   quantization/fp8_e5m2_kv_cache
   quantization/fp8_e5m2_kvcache
   quantization/fp8_e4m3_kvcache

.. toctree::
   :maxdepth: 2
@@ -21,6 +21,8 @@ This document provides a high-level guide on integrating a `HuggingFace Transfor
Start by forking our `GitHub`_ repository and then :ref:`build it from source <build_from_source>`.
This gives you the ability to modify the codebase and test your model.

.. tip::
    If you don't want to fork the repository and modify vLLM's codebase, please refer to the "Out-of-Tree Model Integration" section below.

1. Bring your model code
------------------------
@@ -94,3 +96,28 @@ This method should load the weights from the HuggingFace's checkpoint file and a
----------------------

Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/__init__.py>`_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/model_loader.py>`_.

6. Out-of-Tree Model Integration
--------------------------------------------

We also provide a way to integrate a model without modifying the vLLM codebase. Steps 2, 3, and 4 are still required, but you can skip steps 1 and 5.

Just add the following lines in your code:

.. code-block:: python

    from vllm import ModelRegistry
    from your_code import YourModelForCausalLM
    ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)

If you are running the API server with `python -m vllm.entrypoints.openai.api_server args`, you can wrap the entrypoint with the following code:

.. code-block:: python

    from vllm import ModelRegistry
    from your_code import YourModelForCausalLM
    ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
    import runpy
    runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')

Save the above code in a file and run it with `python your_file.py args`.
@@ -36,7 +36,7 @@ Below, you can find an explanation of every engine argument for vLLM:

    Directory to download and load the weights; defaults to the default Hugging Face cache directory.

.. option:: --load-format {auto,pt,safetensors,npcache,dummy}
.. option:: --load-format {auto,pt,safetensors,npcache,dummy,tensorizer}

    The format of the model weights to load.

@@ -45,6 +45,7 @@ Below, you can find an explanation of every engine argument for vLLM:
    * "safetensors" will load the weights in the safetensors format.
    * "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading.
    * "dummy" will initialize the weights with random values, mainly for profiling.
    * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer <https://github.com/coreweave/tensorizer>`_. See `examples/tensorize_vllm_model.py <https://github.com/vllm-project/vllm/blob/main/examples/tensorize_vllm_model.py>`_ for how to serialize a vLLM model and for more information.

.. option:: --dtype {auto,half,float16,bfloat16,float,float32}

@@ -118,3 +119,19 @@ Below, you can find an explanation of every engine argument for vLLM:
.. option:: --quantization (-q) {awq,squeezellm,None}

    Method used to quantize the weights.

Async Engine Arguments
----------------------

Below are the additional arguments related to the asynchronous engine:

.. option:: --engine-use-ray

    Use Ray to start the LLM engine in a process separate from the server process.

.. option:: --disable-log-requests

    Disable logging of requests.

.. option:: --max-log-len

    Maximum number of prompt characters or prompt token IDs printed in the log. Defaults to unlimited.
@@ -83,13 +83,17 @@ Alongside each architecture, we include some popular models that use it.
    - LLaMA, LLaMA-2, Vicuna, Alpaca, Yi
    - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
    - ✅︎
  * - :code:`MiniCPMForCausalLM`
    - MiniCPM
    - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
    -
  * - :code:`MistralForCausalLM`
    - Mistral, Mistral-Instruct
    - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
    - ✅︎
  * - :code:`MixtralForCausalLM`
    - Mixtral-8x7B, Mixtral-8x7B-Instruct
    - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.
    - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, :code:`mistral-community/Mixtral-8x22B-v0.1`, etc.
    - ✅︎
  * - :code:`MPTForCausalLM`
    - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
@@ -164,3 +168,29 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
    llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model
    output = llm.generate("Hello, my name is")
    print(output)

Model Support Policy
---------------------

At vLLM, we are committed to facilitating the integration and support of third-party models within our ecosystem. Our approach is designed to balance the need for robustness and the practical limitations of supporting a wide range of models. Here’s how we manage third-party model support:

1. **Community-Driven Support**: We encourage community contributions for adding new models. When a user requests support for a new model, we welcome pull requests (PRs) from the community. These contributions are evaluated primarily on the sensibility of the output they generate, rather than strict consistency with existing implementations such as those in transformers. **Call for contribution:** PRs coming directly from model vendors are greatly appreciated!

2. **Best-Effort Consistency**: While we aim to maintain a level of consistency between the models implemented in vLLM and other frameworks like transformers, complete alignment is not always feasible. Factors like acceleration techniques and the use of low-precision computations can introduce discrepancies. Our commitment is to ensure that the implemented models are functional and produce sensible results.

3. **Issue Resolution and Model Updates**: Users are encouraged to report any bugs or issues they encounter with third-party models. Proposed fixes should be submitted via PRs, with a clear explanation of the problem and the rationale behind the proposed solution. If a fix for one model impacts another, we rely on the community to highlight and address these cross-model dependencies. Note: for bugfix PRs, it is good etiquette to inform the original author to seek their feedback.

4. **Monitoring and Updates**: Users interested in specific models should monitor the commit history for those models (e.g., by tracking changes in the main/vllm/model_executor/models directory). This proactive approach helps users stay informed about updates and changes that may affect the models they use.

5. **Selective Focus**: Our resources are primarily directed towards models with significant user interest and impact. Models that are less frequently used may receive less attention, and we rely on the community to play a more active role in their upkeep and improvement.

Through this approach, vLLM fosters a collaborative environment where both the core development team and the broader community contribute to the robustness and diversity of the third-party models supported in our ecosystem.

Note that, as an inference engine, vLLM does not introduce new models. Therefore, all models supported by vLLM are third-party models in this regard.

We have the following levels of testing for models:

1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to `test_models.py <https://github.com/vllm-project/vllm/blob/main/tests/models/test_models.py>`_ and `test_big_models.py <https://github.com/vllm-project/vllm/blob/main/tests/models/test_big_models.py>`_ for the models that have passed this test.
2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test.
3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to `functionality tests <https://github.com/vllm-project/vllm/tree/main/tests>`_ and `examples <https://github.com/vllm-project/vllm/tree/main/examples>`_ for the models that have passed this test.
4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category.
docs/source/quantization/fp8_e4m3_kvcache.rst (new file, 49 lines)
@@ -0,0 +1,49 @@
.. _fp8_e4m3_kvcache:

FP8 E4M3 KV Cache
==================

Quantizing the KV cache to FP8 reduces its memory footprint. This increases the number of tokens that can be stored in the cache,
improving throughput. OCP (Open Compute Project www.opencompute.org) specifies two common 8-bit floating point data formats: E5M2
(5 exponent bits and 2 mantissa bits) and E4M3FN (4 exponent bits and 3 mantissa bits), often shortened as E4M3. One benefit of
the E4M3 format over E5M2 is that floating point numbers are represented in higher precision. However, the small dynamic range of
FP8 E4M3 (±240.0 can be represented) typically necessitates the use of a higher-precision (typically FP32) scaling factor alongside
each quantized tensor. For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling
factors of a finer granularity (e.g. per-channel).
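
The arithmetic behind a per-tensor scale can be sketched in a few lines. This is an illustration only (not vLLM's
implementation), and it assumes a PyTorch build that exposes the :code:`float8_e4m3fn` dtype:

.. code-block:: python

    import torch

    x = torch.randn(16)                      # high-precision KV-cache values
    scale = x.abs().max() / 240.0            # ±240.0: the E4M3 range quoted above
    q = (x / scale).to(torch.float8_e4m3fn)  # quantize: divide by the scale
    x_hat = q.to(torch.float32) * scale      # dequantize: multiply the scale back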

These scaling factors can be specified by passing an optional quantization param JSON to the LLM engine at load time. If
this JSON is not specified, scaling factors default to 1.0. These scaling factors are typically obtained when running an
unquantized model through a quantizer tool (e.g. AMD quantizer or NVIDIA AMMO).

To install AMMO (AlgorithMic Model Optimization):

.. code-block:: console

    $ pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo

Studies have shown that FP8 E4M3 quantization typically only minimally degrades inference accuracy. The most recent silicon
offerings (e.g. AMD MI300, NVIDIA Hopper or later) support native hardware conversion to and from fp32, fp16, bf16, etc.
Thus, LLM inference is greatly accelerated with minimal accuracy loss.

Here is an example of how to enable this feature:

.. code-block:: python

    # Two float8_e4m3fn kv cache scaling factor files are provided under tests/fp8_kv; please refer to
    # https://github.com/vllm-project/vllm/blob/main/examples/fp8/README.md to generate kv_cache_scales.json of your own.

    from vllm import LLM, SamplingParams
    sampling_params = SamplingParams(temperature=1.3, top_p=0.8)
    llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
              kv_cache_dtype="fp8",
              quantization_param_path="./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json")
    prompt = "London is the capital of"
    out = llm.generate(prompt, sampling_params)[0].outputs[0].text
    print(out)

    # output w/ scaling factors:  England, the United Kingdom, and one of the world's leading financial,
    # output w/o scaling factors:  England, located in the southeastern part of the country. It is known

Note that prefix caching does not currently work with the FP8 KV cache enabled; the forward_prefix kernel would need to handle different KV and cache data types.
@@ -1,4 +1,4 @@
.. _fp8_e5m2_kv_cache:
.. _fp8_kv_cache:

FP8 E5M2 KV Cache
==================
@@ -21,7 +21,7 @@ Here is an example of how to enable this feature:
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    # Create an LLM.
    llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8_e5m2")
    llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8")
    # Generate texts from the prompts. The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
@@ -31,3 +31,6 @@ Here is an example of how to enable this feature:
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


Note that prefix caching does not currently work with the FP8 KV cache enabled; the forward_prefix kernel would need to handle different KV and cache data types.
@ -4,7 +4,7 @@ vLLM provides an HTTP server that implements OpenAI's [Completions](https://plat
|
||||
|
||||
You can start the server using Python, or using [Docker](deploying_with_docker.rst):
|
||||
```bash
|
||||
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-hf --dtype float32 --api-key token-abc123
|
||||
python -m vllm.entrypoints.openai.api_server --model mistralai/Mistral-7B-Instruct-v0.2 --dtype auto --api-key token-abc123
|
||||
```
|
||||
|
||||
To call the server, you can use the official OpenAI Python client library, or any other HTTP client.
|
||||
@ -16,9 +16,8 @@ client = OpenAI(
|
||||
)
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model="meta-llama/Llama-2-7b-hf",
|
||||
model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello!"}
|
||||
]
|
||||
)
|
||||
@ -38,9 +37,8 @@ Or directly merge them into the JSON payload if you are using HTTP call directly
|
||||
|
||||
```python
|
||||
completion = client.chat.completions.create(
|
||||
model="meta-llama/Llama-2-7b-hf",
|
||||
model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
|
||||
],
|
||||
extra_body={
|
||||
@ -89,7 +87,7 @@ In order for the language model to support chat protocol, vLLM requires the mode
|
||||
a chat template in its tokenizer configuration. The chat template is a Jinja2 template that
|
||||
specifies how are roles, messages, and other chat-specific tokens are encoded in the input.
|
||||
|
||||
An example chat template for `meta-llama/Llama-2-7b-chat-hf` can be found [here](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/09bd0f49e16738cdfaa6e615203e126038736eb0/tokenizer_config.json#L12)
|
||||
An example chat template for `mistralai/Mistral-7B-Instruct-v0.2` can be found [here](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2#instruction-format)
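
As a quick way to see what a given chat template produces, you can render a message list locally with the Hugging Face tokenizer's `apply_chat_template` method (this snippet is an illustration and is not part of vLLM itself):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
messages = [{"role": "user", "content": "Hello!"}]
# Renders the Jinja2 template stored in the tokenizer configuration into the
# prompt string that the model will actually see.
prompt = tokenizer.apply_chat_template(messages, tokenize=False,
                                       add_generation_prompt=True)
print(prompt)
```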

Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those models,
you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat
examples/fp8/README.md (new file, 96 lines)
@@ -0,0 +1,96 @@
# FP8 KV Cache

This utility extracts the KV cache scaling factors from a quantized HF (Hugging Face) model. The extracted scaling factors are saved to a JSON file, which can later be used by vLLM during runtime. This tool is particularly useful when the KV cache data type is FP8 and is intended for use on ROCm (AMD GPU) platforms.

## Prerequisites

- Python 3.x
- PyTorch
- NumPy
- Hugging Face Transformers
- Hugging Face Hub
- AMMO

Before incorporating the FP8 datatype for inference workloads, you must adhere to the following steps:
1. Install all necessary prerequisites and dependencies.
2. Convert the HF model into a quantized HF model.
3. Extract the KV cache scaling factors from the quantized HF model.
4. Load the KV cache scaling factors into vLLM.

### 2. Convert the HF model into a quantized HF model
Note: The following steps are adapted from the [TensorRT-LLM repository](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/README.md).

`quantize.py` (examples/fp8/quantizer/quantize.py) uses the quantization toolkit (AMMO) to calibrate the PyTorch models and export TensorRT-LLM checkpoints. Each TensorRT-LLM checkpoint contains a config file (in .json format) and one or several rank weight files (in .safetensors format).

The detailed quantization toolkit (AMMO) conversion guide for FP8 can be found at `examples/fp8/quantizer/README.md`.

### 3. Extract the KV cache scaling factors from the quantized HF model
`extract_scales.py` (examples/fp8/extract_scales.py) can be utilized to extract the KV cache scaling factors from your quantized HF model; however, at the moment this tool exclusively supports Llama 2 models. It is also important to note the following:
1. **File Structure**: The utility operates under the assumption that all parameters, including KV cache scaling factors, corresponding to a particular Tensor Parallelism (TP) rank are stored in a single file. These files must adhere to a specific naming convention where the TP rank is immediately identified after a specific keyword (e.g., "rank") in the filename.

2. **TP Decomposition**: The utility assumes consistency between the TP decomposition employed by the quantizer tool and that used by vLLM.

3. **AMMO Compatibility**: Currently, the generated KV cache scaling factors for AMMO remain uniform across all TP ranks.

```
# prerequisites:
# - Quantized HF LLaMa 2 model
python3 examples/fp8/extract_scales.py --help
Usage: extract_scales.py [-h] --quantized_model QUANTIZED_MODEL [--load_format {auto,safetensors,npz,pt}] [--output_dir OUTPUT_DIR] [--output_name OUTPUT_NAME] [--tp_size TP_SIZE]

KV Scale Extraction Example

Required arguments:
--quantized_model: Specify either the local path to, or name of, a quantized HF model. It is expected that the quantization format is FP8_E4M3, for use on ROCm (AMD GPU).
Optional arguments:
--cache_dir: Specify a cache directory to use in the event of a HF model download. (Default: None)
--load_format: Specify the format of the model's tensor files containing the KV cache scaling factors. (Choices: auto, safetensors, npz, pt; Default: auto)
--revision: Specify the model's revision number. (Default: None)
--output_dir: Specify the output directory. By default the KV cache scaling factors will be saved in the model directory. (Default: None)
--output_name: Specify the output filename. (Default: kv_cache_scales.json)
--tp_size: Specify the tensor-parallel (TP) size that the quantized model should correspond to. If specified, during KV cache scaling factor extraction the observed TP size will be checked against this and an error will be raised if there is a mismatch. (Default: None)
```
```
Example:
python3 examples/fp8/extract_scales.py --quantized_model <QUANTIZED_MODEL_DIR> --tp_size <TENSOR_PARALLEL_SIZE> --output_dir <PATH_TO_OUTPUT_DIR>
```
### 4. Load the KV cache scaling factors into vLLM
This script evaluates the inference throughput of language models using various backends such as vLLM. It measures the time taken to process a given number of prompts and generate sequences for each prompt. The recently generated KV cache scaling factors are now integrated into the benchmarking process and allow KV cache scaling factors to be utilized for FP8.
```
# prerequisites:
# - LLaMa 2 kv_cache_scales.json file

python3 benchmarks/benchmark_throughput.py --help
usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL]
                               [--tokenizer TOKENIZER] [--quantization {awq,gptq,squeezellm,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N]
                               [--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code]
                               [--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}]
                               [--quantization-param-path KV_CACHE_quantization_param_path]

Benchmark Throughput Example

optional arguments:
-h, --help            show this help message and exit
--backend {vllm,hf,mii}
--dataset DATASET     Path to the dataset.
--input-len INPUT_LEN Input prompt length for each request
--output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset.
--model MODEL
--tokenizer TOKENIZER
--quantization {awq,gptq,squeezellm,None}, -q {awq,gptq,squeezellm,None}
--tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE
--n N                 Number of generated sequences per prompt.
--use-beam-search
--num-prompts NUM_PROMPTS Number of prompts to process.
--seed SEED
--hf-max-batch-size HF_MAX_BATCH_SIZE Maximum batch size for HF backend.
--trust-remote-code   trust remote code from huggingface
--max-model-len MAX_MODEL_LEN Maximum length of a sequence (including prompt and output). If None, will be derived from the model.
--dtype {auto,half,float16,bfloat16,float,float32} data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
--enforce-eager       enforce eager execution
--kv-cache-dtype {auto,fp8} Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on CUDA versions greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
--quantization-param-path QUANT_PARAM_JSON Path to the JSON file containing the KV cache scaling factors. This should generally be supplied when the KV cache dtype is FP8; otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on CUDA versions greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
```
```
Example:
python3 benchmarks/benchmark_throughput.py --input-len <INPUT_LEN> --output-len <OUTPUT_LEN> -tp <TENSOR_PARALLEL_SIZE> --kv-cache-dtype fp8 --quantization-param-path <path/to/kv_cache_scales.json> --model <path-to-llama2>
```
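
Outside the benchmark script, the same scaling factors can also be supplied directly through the Python API. The following is a short sketch based on the FP8 E4M3 documentation in this repository; the paths are placeholders:

```python
from vllm import LLM, SamplingParams

# Placeholder paths: point these at your model and the generated kv_cache_scales.json.
llm = LLM(model="<path-to-llama2>",
          kv_cache_dtype="fp8",
          quantization_param_path="<path/to/kv_cache_scales.json>")
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
out = llm.generate("London is the capital of", sampling_params)
print(out[0].outputs[0].text)
```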
examples/fp8/extract_scales.py (new file, 367 lines)
@@ -0,0 +1,367 @@
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
|
||||
|
||||
|
||||
# Adapted from vllm/model_executor/weight_utils.py
|
||||
# The main differences are that we add the NPZ format and simplify
|
||||
# its functionality drastically for our purposes (e.g. we assume that
|
||||
# the quantized model exists locally and there is no need to download it)
|
||||
def _prepare_hf_weights(
|
||||
quantized_model_dir: str,
|
||||
load_format: str = "auto",
|
||||
fall_back_to_pt: bool = True,
|
||||
) -> Tuple[str, List[str], bool]:
|
||||
if not os.path.isdir(quantized_model_dir):
|
||||
raise FileNotFoundError(
|
||||
f"The quantized model directory `{quantized_model_dir}` "
|
||||
"does not exist.")
|
||||
use_safetensors = False
|
||||
# Some quantized models use .pt files for storing the weights.
|
||||
if load_format == "auto":
|
||||
allow_patterns = ["*.safetensors", "*.bin"]
|
||||
elif load_format == "safetensors":
|
||||
use_safetensors = True
|
||||
allow_patterns = ["*.safetensors"]
|
||||
elif load_format == "pt":
|
||||
allow_patterns = ["*.pt"]
|
||||
elif load_format == "npz":
|
||||
allow_patterns = ["*.npz"]
|
||||
else:
|
||||
raise ValueError(f"Unknown load_format: {load_format}")
|
||||
if fall_back_to_pt:
|
||||
allow_patterns += ["*.pt"]
|
||||
|
||||
hf_weights_files: List[str] = []
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(
|
||||
os.path.join(quantized_model_dir, pattern))
|
||||
if len(hf_weights_files) > 0:
|
||||
if pattern == "*.safetensors":
|
||||
use_safetensors = True
|
||||
break
|
||||
|
||||
if not use_safetensors:
|
||||
# Exclude files that are not needed for inference.
|
||||
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
|
||||
blacklist = [
|
||||
"training_args.bin",
|
||||
"optimizer.bin",
|
||||
"optimizer.pt",
|
||||
"scheduler.pt",
|
||||
"scaler.pt",
|
||||
]
|
||||
hf_weights_files = [
|
||||
f for f in hf_weights_files
|
||||
if not any(f.endswith(x) for x in blacklist)
|
||||
]
|
||||
|
||||
if len(hf_weights_files) == 0:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any model weights with `{quantized_model_dir}`")
|
||||
|
||||
return hf_weights_files, use_safetensors
|
||||
|
||||
|
||||
# Adapted from vllm/model_executor/weight_utils.py
|
||||
def _hf_tensorfile_iterator(filename: str, load_format: str,
|
||||
use_safetensors: bool):
|
||||
if load_format == "npz":
|
||||
assert not use_safetensors
|
||||
with np.load(filename) as data:
|
||||
for name in data.files:
|
||||
param = torch.from_numpy(data[name])
|
||||
yield name, param
|
||||
elif use_safetensors:
|
||||
with safe_open(filename, framework="pt") as f:
|
||||
for name in f.keys(): # NOQA: SIM118
|
||||
param = f.get_tensor(name)
|
||||
yield name, param
|
||||
else:
|
||||
state = torch.load(filename, map_location="cpu")
|
||||
for name, param in state.items():
|
||||
yield name, param
|
||||
del state
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def _kv_scales_extractor(
|
||||
hf_tensor_files: Iterable[str],
|
||||
use_safetensors: bool,
|
||||
rank_keyword: str = "rank",
|
||||
expected_tp_size: Optional[int] = None) -> Dict[int, Dict[int, float]]:
|
||||
"""
|
||||
Given a list of files containing tensor data, attempt to extract KV cache
|
||||
scales from these files. Intended as a helper function taking in the output
|
||||
from _prepare_hf_weights.
|
||||
Args:
|
||||
rank_keyword Matches the number immediately after this keyword in the
|
||||
tensor filename to determine the TP rank corresponding
|
||||
to said tensor file
|
||||
expected_tp_size If specified, the TP size of the tensor files is checked
|
||||
against this and an error is raised if they don't match.
|
||||
Returns a dictionary mapping TP ranks to their relevant KV cache scales.
|
||||
The per-rank scales are themselves represented as a dictionary of layer
|
||||
indices to the respective per-layer scale.
|
||||
"""
|
||||
for char in rank_keyword:
|
||||
assert not char.isdecimal(
|
||||
), f"Rank keyword {rank_keyword} contains a numeric character!"
|
||||
rank_scales_map = {}
|
||||
for tensor_file in hf_tensor_files:
|
||||
try:
|
||||
rank_idx = tensor_file.find(rank_keyword)
|
||||
if rank_idx != -1:
|
||||
start_idx = rank_idx + len(rank_keyword)
|
||||
stop_idx = start_idx
|
||||
while stop_idx < len(
|
||||
tensor_file) and tensor_file[stop_idx].isdecimal():
|
||||
stop_idx += 1
|
||||
if stop_idx == start_idx:
|
||||
raise RuntimeError("Did not find rank # in filename.")
|
||||
rank = int(tensor_file[start_idx:stop_idx])
|
||||
elif len(hf_tensor_files) == 1:
|
||||
# Since there is only one tensor file, we can assume
|
||||
# that it's intended for TP rank 0
|
||||
rank = 0
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Filename does not contain '{rank_keyword}'.")
|
||||
except RuntimeError:
|
||||
print("Unable to determine TP rank "
|
||||
f"corresponding to file '{tensor_file}'")
|
||||
raise
|
||||
|
||||
if rank not in rank_scales_map:
|
||||
layer_scales_map = {}
|
||||
rank_scales_map[rank] = layer_scales_map
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Tensor file '{tensor_file}' shares TP rank {rank} "
|
||||
"with another tensor file.")
|
||||
|
||||
module_delimiter = ":" if args.load_format == "npz" else "."
|
||||
for name, param in _hf_tensorfile_iterator(tensor_file,
|
||||
args.load_format,
|
||||
use_safetensors):
|
||||
if "kv_cache_scaling_factor" in name:
|
||||
nums = [
|
||||
int(s) for s in name.split(module_delimiter)
|
||||
if s.isdecimal()
|
||||
]
|
||||
assert len(
|
||||
nums) == 1, f"Could not determine layer idx for {name}"
|
||||
layer_idx = nums[0]
|
||||
assert layer_idx not in layer_scales_map, f"Duplicate scaling"\
|
||||
f" factor corresponding to layer {layer_idx}"
|
||||
try:
|
||||
layer_scales_map[layer_idx] = param.item()
|
||||
except RuntimeError:
|
||||
print(
|
||||
"This utility supports only per-tensor scalar scales "
|
||||
f"for now. The tensor\n {name} = {param} \nis an "
|
||||
"invalid scale factor.")
|
||||
raise
|
||||
|
||||
if all(
|
||||
len(layer_scales_map) == 0
|
||||
for layer_scales_map in rank_scales_map.values()):
|
||||
# Note: this is true even if the rank_scales_map is empty
|
||||
print("WARNING: No KV cache scale factors found. No output saved.")
|
||||
return None
|
||||
empirical_tp_world_size = max(rank_scales_map.keys()) + 1
|
||||
if expected_tp_size is not None:
|
||||
assert expected_tp_size == empirical_tp_world_size, \
|
||||
f"User expected TP world size = {expected_tp_size} " \
|
||||
"from model but tool is expecting TP world size = " \
|
||||
f"{empirical_tp_world_size} from model instead."
|
||||
for i in range(empirical_tp_world_size):
|
||||
assert i in rank_scales_map, "Expected TP world size = "\
|
||||
f"{empirical_tp_world_size} but did not find KV " \
|
||||
f"cache scaling factors for TP rank {i}"
|
||||
print(f"Found TP world size = {empirical_tp_world_size} "
|
||||
"when extracting KV cache scales!")
|
||||
return rank_scales_map
|
||||
|
||||
|
||||
def _metadata_extractor(quantized_model_dir: str,
|
||||
metadata_extract_fns: \
|
||||
Dict[str, Callable[[Dict[str, Any]], Any]]) \
|
||||
-> Dict[str, Any]:
|
||||
"""
|
||||
Given a directory containing quantized model files, this function
|
||||
aims to extract metadata from the JSON files within this directory.
|
||||
Each JSON file is expected to represent a dictionary in JSON
|
||||
format (referred to as a "JSON-dictionary"). Metadata extraction is
|
||||
defined by a dictionary called metadata_extract_fns, where each
|
||||
metadata field name is mapped to an extraction function.
|
||||
|
||||
These extraction functions are designed to take a JSON-dictionary
|
||||
as their only argument and return the corresponding metadata.
|
||||
While extraction functions are permitted to raise exceptions, they
|
||||
should only raise a KeyError or ValueError if the metadata field
|
||||
cannot be extracted from the current JSON-dictionary, yet there's
|
||||
a possibility of finding it in another JSON-dictionary.
|
||||
|
||||
The function returns a dictionary that maps metadata fields to
|
||||
their extracted data. The keys of this dictionary correspond exactly
|
||||
to those in metadata_extract_fns. If any fields fail to be extracted,
|
||||
their corresponding values are set to None, and a warning is printed.
|
||||
"""
|
||||
if not os.path.isdir(quantized_model_dir):
|
||||
raise FileNotFoundError(
|
||||
f"The quantized model directory `{quantized_model_dir}` "
|
||||
"does not exist.")
|
||||
metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json"))
|
||||
|
||||
result = {}
|
||||
for file in metadata_files:
|
||||
with open(file) as f:
|
||||
try:
|
||||
metadata = json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
print(f"Could not parse `{file}` as a valid metadata file,"
|
||||
" skipping it.")
|
||||
continue
|
||||
if not isinstance(metadata, dict):
|
||||
print(f"The file `{file}` does not correspond to a "
|
||||
"JSON-serialized dictionary, skipping it.")
|
||||
continue
|
||||
for metadata_name, extract_fn in metadata_extract_fns.items():
|
||||
try:
|
||||
metadata_info = extract_fn(metadata)
|
||||
if metadata_name not in result:
|
||||
result[metadata_name] = metadata_info
|
||||
elif metadata_info != result[metadata_name]:
|
||||
raise RuntimeError(
|
||||
"Metadata mismatch! Originally found "
|
||||
f"{metadata_name} = {result[metadata_name]} but "
|
||||
f"now found {metadata_name} = {metadata_info} in "
|
||||
f"`{file}`")
|
||||
except KeyError:
|
||||
# It is possible that a given file does not contain some
|
||||
# of our selected metadata as it could be located in some
|
||||
# other metadata file.
|
||||
# 'EFINAE': extract_fn failure is not an error.
|
||||
pass
|
||||
except ValueError:
|
||||
# See above.
|
||||
pass
|
||||
|
||||
# Warn if we cannot find any of the requested metadata
|
||||
for metadata_name in metadata_extract_fns:
|
||||
if metadata_name not in result:
|
||||
print("WARNING: Unable to find requested metadata field "
|
||||
f"`{metadata_name}`, setting it to None.")
|
||||
result[metadata_name] = None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main(args):
|
||||
metadata_extract_fns = {
|
||||
"model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"],
|
||||
"tp_size": lambda json_dict: int(json_dict["tensor_parallel"]),
|
||||
"model_dtype": lambda json_dict: json_dict["dtype"]
|
||||
}
|
||||
recovered_metadata = _metadata_extractor(args.quantized_model,
|
||||
metadata_extract_fns)
|
||||
if args.tp_size is not None:
|
||||
metadata_tp_size = recovered_metadata["tp_size"]
|
||||
if metadata_tp_size is not None:
|
||||
assert args.tp_size == metadata_tp_size, \
|
||||
f"User expected TP world size = {args.tp_size} " \
|
||||
f"but found TP world size = {metadata_tp_size} from metadata!"
|
||||
expected_tp_size = args.tp_size or recovered_metadata["tp_size"]
|
||||
rank_keyword = "rank"
|
||||
hf_tensor_files, use_safetensors = _prepare_hf_weights(
|
||||
args.quantized_model, args.load_format)
|
||||
rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors,
|
||||
rank_keyword, expected_tp_size)
|
||||
# Postprocess: formatting to the current schema. Consider pulling it
|
||||
# out into a dedicated function should it ever become more complicated.
|
||||
rank_scales_map = {
|
||||
rank: {k: scale[k]
|
||||
for k in sorted(scale.keys())}
|
||||
for rank, scale in rank_scales_map.items()
|
||||
}
|
||||
# TODO: Expand this with activation and weights scaling factors when
|
||||
# they are used in the future
|
||||
schema = QuantParamSchema(
|
||||
model_type=recovered_metadata["model_type"],
|
||||
kv_cache={
|
||||
"dtype": ("float8_e4m3fn" if len(rank_scales_map) > 0 else
|
||||
recovered_metadata["model_dtype"]),
|
||||
"scaling_factor":
|
||||
rank_scales_map
|
||||
},
|
||||
)
|
||||
|
||||
if args.output_dir is None:
|
||||
output_file = os.path.join(args.quantized_model, args.output_name)
|
||||
else:
|
||||
if not os.path.isdir(args.output_dir):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
output_file = os.path.join(args.output_dir, args.output_name)
|
||||
|
||||
with open(output_file, 'w') as f:
|
||||
f.write(schema.model_dump_json(indent=4))
|
||||
print(f"Completed! KV cache scaling factors saved to {output_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="This simple utility extracts the "
|
||||
"KV cache scaling factors from a quantized HF model "
|
||||
"and saves them to a JSON file compatible with later "
|
||||
"use by vLLM (pass this file to the appropriate "
|
||||
"runtime typically using the argument "
|
||||
"--quantization-param-path <filename>). This is only used "
|
||||
"if the KV cache dtype is FP8 and on ROCm (AMD GPU).")
|
||||
parser.add_argument(
|
||||
"--quantized_model",
|
||||
help="Specify the directory containing a single quantized HF model. "
|
||||
"It is expected that the quantization format is FP8_E4M3, for use "
|
||||
"on ROCm (AMD GPU).",
|
||||
required=True)
|
||||
parser.add_argument(
|
||||
"--load_format",
|
||||
help="Optionally specify the format of the model's tensor files "
|
||||
"containing the KV cache scaling factors.",
|
||||
choices=["auto", "safetensors", "npz", "pt"],
|
||||
default="auto")
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
help="Optionally specify the output directory. By default the "
|
||||
"KV cache scaling factors will be saved in the model directory, "
|
||||
"however you can override this behavior here.",
|
||||
default=None)
|
||||
parser.add_argument(
|
||||
"--output_name",
|
||||
help="Optionally specify the output filename.",
|
||||
# TODO: Change this once additional scaling factors are enabled
|
||||
default="kv_cache_scales.json")
|
||||
parser.add_argument(
|
||||
"--tp_size",
|
||||
help="Optionally specify the tensor-parallel (TP) size that the "
|
||||
"quantized model should correspond to. If specified, during KV "
|
||||
"cache scaling factor extraction the observed TP size will be "
|
||||
"checked against this and an error will be raised if there is "
|
||||
"a mismatch. If not specified, the quantized model's expected "
|
||||
"TP size is instead inferred from the largest TP rank observed. "
|
||||
"The expected TP size is cross-checked against the TP ranks "
|
||||
"observed in the quantized model and an error is raised if any "
|
||||
"discrepancies are found.",
|
||||
default=None,
|
||||
type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
examples/fp8/quantizer/README.md (new file, 32 lines)
@@ -0,0 +1,32 @@
### Quantizer Utilities
`quantize.py`: NVIDIA quantization utilities using AMMO, ported from TensorRT-LLM:
`https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py`

### Prerequisite

#### AMMO (AlgorithMic Model Optimization) installation: nvidia-ammo 0.7.1 or later
`pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo`

#### AMMO download (code and docs)
`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.5.0.tar.gz`
`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.7.1.tar.gz`

### Usage

#### Run on an H100 system for speed if using FP8; the number of GPUs depends on the model size

#### Example: quantize a Llama2-7b model from HF to FP8 with an FP8 KV cache:
`python quantize.py --model_dir ./ll2-7b --dtype float16 --qformat fp8 --kv_cache_dtype fp8 --output_dir ./ll2_7b_fp8 --calib_size 512 --tp_size 1`

Outputs: the model structure and the quantized model & parameters (with scaling factors) are written as JSON and Safetensors (the npz file is generated only for reference).
```
# ll ./ll2_7b_fp8/
total 19998244
drwxr-xr-x 2 root root        4096 Feb  7 01:08 ./
drwxrwxr-x 8 1060 1061        4096 Feb  7 01:08 ../
-rw-r--r-- 1 root root      176411 Feb  7 01:08 llama_tp1.json
-rw-r--r-- 1 root root 13477087480 Feb  7 01:09 llama_tp1_rank0.npz
-rw-r--r-- 1 root root  7000893272 Feb  7 01:08 rank0.safetensors
#
```
examples/fp8/quantizer/quantize.py (new file, 367 lines)
@@ -0,0 +1,367 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # noqa: E501
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Adapted from examples/quantization/hf_ptq.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
|
||||
import ammo.torch.quantization as atq
|
||||
import numpy as np
|
||||
import torch
|
||||
from ammo.torch.export import export_model_config
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
RAND_SEED = 1234
|
||||
MAX_SEQ_LEN = 2048
|
||||
|
||||
EMPTY_CFG = {
|
||||
"quant_cfg": {
|
||||
"*weight_quantizer": {
|
||||
"enable": False,
|
||||
},
|
||||
"*input_quantizer": {
|
||||
"enable": False
|
||||
},
|
||||
"*lm_head*": {
|
||||
"enable": False
|
||||
},
|
||||
"*output_layer*": {
|
||||
"enable": False
|
||||
},
|
||||
"default": {
|
||||
"enable": False
|
||||
},
|
||||
},
|
||||
"algorithm": "max",
|
||||
}
|
||||
|
||||
KV_CACHE_CFG = {
|
||||
"*.query_key_value.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
"*.Wqkv.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
"*.W_pack.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
"*.c_attn.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
"*.k_proj.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
"*.v_proj.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
}
|
||||
|
||||
QUANT_CFG_CHOICES = {
|
||||
"int8_sq": atq.INT8_SMOOTHQUANT_CFG,
|
||||
"fp8": atq.FP8_DEFAULT_CFG,
|
||||
"int4_awq": atq.INT4_AWQ_CFG,
|
||||
"w4a8_awq": atq.W4A8_AWQ_BETA_CFG,
|
||||
"int8_wo": EMPTY_CFG,
|
||||
"int4_wo": EMPTY_CFG,
|
||||
"full_prec": EMPTY_CFG,
|
||||
}
|
||||
|
||||
MODEL_NAME_PATTERN_MAP = {
|
||||
"GPT2": "gpt2",
|
||||
"Xverse": "llama",
|
||||
"Llama": "llama",
|
||||
"Mistral": "llama",
|
||||
"GPTJ": "gptj",
|
||||
"FalconForCausalLM": "falcon",
|
||||
"RWForCausalLM": "falcon",
|
||||
"baichuan": "baichuan",
|
||||
"MPT": "mpt",
|
||||
"Bloom": "bloom",
|
||||
"ChatGLM": "chatglm",
|
||||
"QWen": "qwen",
|
||||
}
|
||||
|
||||
|
||||
def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None):
|
||||
print(f"Initializing tokenizer from {ckpt_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
ckpt_path,
|
||||
model_max_length=max_seq_len,
|
||||
padding_side="left",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if model_type and model_type == "qwen":
|
||||
# qwen use token id 151643 as pad and eos tokens
|
||||
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643)
|
||||
tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643)
|
||||
|
||||
# can't set attribute 'pad_token' for "<unk>"
|
||||
if tokenizer.pad_token != "<unk>":
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
assert (tokenizer.pad_token
|
||||
is not None), f"Pad token for {model_type} cannot be set!"
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_model(ckpt_path, dtype="fp16", device="cuda"):
|
||||
print(f"Initializing model from {ckpt_path}")
|
||||
if dtype == "bf16" or dtype == "bfloat16":
|
||||
dtype = torch.bfloat16
|
||||
elif dtype == "fp16" or dtype == "float16":
|
||||
dtype = torch.float16
|
||||
elif dtype == "fp32" or dtype == "float32":
|
||||
dtype = torch.float32
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown dtype {dtype}")
|
||||
|
||||
# model_kwargs = {"torch_dtype": dtype}
|
||||
model_kwargs = {"torch_dtype": "auto"}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt_path,
|
||||
device_map="auto",
|
||||
**model_kwargs,
|
||||
trust_remote_code=True)
|
||||
model.eval()
|
||||
|
||||
model_dtype = next(model.parameters()).dtype
|
||||
if dtype != model_dtype:
|
||||
print("[TensorRT-LLM][WARNING] The manually set model data type is "
|
||||
f"{dtype}, but the data type of the HuggingFace model is "
|
||||
f"{model_dtype}.")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_model_type(model):
|
||||
for k, v in MODEL_NAME_PATTERN_MAP.items():
|
||||
if k.lower() in type(model).__name__.lower():
|
||||
return v
|
||||
return None
|
||||
|
||||
|
||||
def get_calib_dataloader(data="cnn_dailymail",
|
||||
tokenizer=None,
|
||||
batch_size=1,
|
||||
calib_size=512,
|
||||
block_size=512,
|
||||
device=None):
|
||||
print("Loading calibration dataset")
|
||||
if data == "pileval":
|
||||
dataset = load_dataset(
|
||||
"json",
|
||||
data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
|
||||
split="train")
|
||||
dataset = dataset["text"][:calib_size]
|
||||
elif data == "cnn_dailymail":
|
||||
dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
|
||||
dataset = dataset["article"][:calib_size]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
batch_encoded = tokenizer.batch_encode_plus(dataset,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=block_size)
|
||||
if device:
|
||||
batch_encoded = batch_encoded.to(device)
|
||||
batch_encoded = batch_encoded["input_ids"]
|
||||
|
||||
calib_dataloader = DataLoader(batch_encoded,
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
|
||||
return calib_dataloader
|
||||
|
||||
|
||||
def quantize_model(model, quant_cfg, calib_dataloader=None):
|
||||
|
||||
def calibrate_loop():
|
||||
if calib_dataloader is None:
|
||||
return
|
||||
"""Adjusts weights and scaling factors based on selected algorithms."""
|
||||
for idx, data in enumerate(calib_dataloader):
|
||||
print(f"Calibrating batch {idx}")
|
||||
model(data)
|
||||
|
||||
print("Starting quantization...")
|
||||
start_time = time.time()
|
||||
atq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
|
||||
end_time = time.time()
|
||||
print("Quantization done. Total time used: {:.2f} s.".format(end_time -
|
||||
start_time))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def main(args):
|
||||
if not torch.cuda.is_available():
|
||||
raise EnvironmentError("GPU is required for inference.")
|
||||
|
||||
random.seed(RAND_SEED)
|
||||
np.random.seed(RAND_SEED)
|
||||
|
||||
model = get_model(args.model_dir, args.dtype, args.device)
|
||||
model_type = get_model_type(model)
|
||||
tokenizer = get_tokenizer(args.model_dir, model_type=model_type)
|
||||
|
||||
if args.qformat in ["full_prec", "int8_wo", "int4_wo"
|
||||
] and args.kv_cache_dtype is None:
|
||||
print(f"No quantization applied, export {args.dtype} model")
|
||||
else:
|
||||
if "awq" in args.qformat:
|
||||
if args.calib_size > 32:
|
||||
print("AWQ calibration could take longer with calib_size = "
|
||||
f"{args.calib_size}, Using calib_size=32 instead")
|
||||
args.calib_size = 32
|
||||
print("\nAWQ calibration could take longer than other calibration "
|
||||
"methods. Please increase the batch size to speed up the "
|
||||
"calibration process. Batch size can be set by adding the "
|
||||
"argument --batch_size <batch_size> to the command line.\n")
|
||||
|
||||
calib_dataloader = get_calib_dataloader(
|
||||
tokenizer=tokenizer,
|
||||
batch_size=args.batch_size,
|
||||
calib_size=args.calib_size,
|
||||
device=args.device,
|
||||
)
|
||||
|
||||
if args.qformat in QUANT_CFG_CHOICES:
|
||||
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported quantization format: {args.qformat}")
|
||||
|
||||
if "awq" in args.qformat:
|
||||
quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat])
|
||||
weight_quantizer = quant_cfg["quant_cfg"][
|
||||
"*weight_quantizer"] # type: ignore
|
||||
if isinstance(weight_quantizer, list):
|
||||
weight_quantizer = weight_quantizer[0]
|
||||
weight_quantizer["block_sizes"][-1] = args.awq_block_size
|
||||
|
||||
if args.kv_cache_dtype is not None:
|
||||
if args.kv_cache_dtype == "fp8":
|
||||
for value in KV_CACHE_CFG.values():
|
||||
value.update({"num_bits": (4, 3)}) # type: ignore
|
||||
quant_cfg["quant_cfg"].update(KV_CACHE_CFG) # type: ignore
|
||||
|
||||
print(quant_cfg)
|
||||
|
||||
model = quantize_model(model, quant_cfg, calib_dataloader)
|
||||
|
||||
with torch.inference_mode():
|
||||
if model_type is None:
|
||||
print(f"Unknown model type {type(model).__name__}. Continue "
|
||||
"exporting...")
|
||||
model_type = f"unknown:{type(model).__name__}"
|
||||
|
||||
export_path = args.output_dir
|
||||
start_time = time.time()
|
||||
|
||||
if args.qformat == "int4_awq" and model_type == "qwen":
|
||||
torch.save(model.state_dict(), export_path)
|
||||
else:
|
||||
export_npz = (model_type not in [
|
||||
'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan'
|
||||
])
|
||||
|
||||
# export safetensors
|
||||
export_model_config(
|
||||
model,
|
||||
model_type,
|
||||
getattr(torch, args.dtype),
|
||||
export_dir=export_path,
|
||||
inference_tensor_parallel=args.tp_size,
|
||||
inference_pipeline_parallel=args.pp_size,
|
||||
# export_tensorrt_llm_config=(not export_npz),
|
||||
export_tensorrt_llm_config=False,
|
||||
export_npz=export_npz)
|
||||
|
||||
# Workaround for wo quantization
|
||||
if args.qformat in ["int8_wo", "int4_wo", "full_prec"]:
|
||||
with open(f"{export_path}/config.json", 'r') as f:
|
||||
tensorrt_llm_config = json.load(f)
|
||||
if args.qformat == "int8_wo":
|
||||
tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16'
|
||||
elif args.qformat == "int4_wo":
|
||||
tensorrt_llm_config["quantization"]["quant_algo"] = 'W4A16'
|
||||
else:
|
||||
tensorrt_llm_config["quantization"]["quant_algo"] = None
|
||||
with open(f"{export_path}/config.json", "w") as f:
|
||||
json.dump(tensorrt_llm_config, f, indent=4)
|
||||
|
||||
end_time = time.time()
|
||||
print("Quantized model exported to {} \nTotal time used {:.2f} s.".
|
||||
format(export_path, end_time - start_time))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--model_dir",
|
||||
help="Specify where the HuggingFace model is",
|
||||
required=True)
|
||||
parser.add_argument("--device", default="cuda")
|
||||
parser.add_argument("--dtype", help="Model data type.", default="float16")
|
||||
parser.add_argument(
|
||||
"--qformat",
|
||||
help="Quantization format.",
|
||||
default="full_prec",
|
||||
choices=[
|
||||
"fp8", "int8_sq", "int4_awq", "w4a8_awq", "int8_wo", "int4_wo",
|
||||
"full_prec"
|
||||
],
|
||||
)
|
||||
parser.add_argument("--batch_size",
|
||||
help="Batch size for calibration.",
|
||||
type=int,
|
||||
default=1)
|
||||
parser.add_argument("--calib_size",
|
||||
help="Number of samples for calibration.",
|
||||
type=int,
|
||||
default=512)
|
||||
parser.add_argument("--output_dir", default="exported_model")
|
||||
parser.add_argument("--tp_size", type=int, default=1)
|
||||
parser.add_argument("--pp_size", type=int, default=1)
|
||||
parser.add_argument("--awq_block_size", type=int, default=128)
|
||||
parser.add_argument("--kv_cache_dtype",
|
||||
help="KV Cache dtype.",
|
||||
default=None,
|
||||
choices=["int8", "fp8", None])
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
examples/tensorize_vllm_model.py (new file, 282 lines)
@ -0,0 +1,282 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from functools import partial
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
|
||||
TensorSerializer, stream_io)
|
||||
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
|
||||
from vllm.distributed import initialize_model_parallel
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.model_executor.tensorizer_loader import TensorizerArgs
|
||||
|
||||
# yapf conflicts with isort for this docstring
|
||||
# yapf: disable
|
||||
"""
|
||||
tensorize_vllm_model.py is a script that can be used to serialize and
|
||||
deserialize vLLM models. These models can be loaded using tensorizer
|
||||
to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint,
|
||||
or locally. Tensor encryption and decryption are also supported, although
|
||||
libsodium must be installed to use it. Install vllm with tensorizer support
|
||||
using `pip install vllm[tensorizer]`.
|
||||
|
||||
To serialize a model, install vLLM from source, then run something
|
||||
like this from the root level of this repository:
|
||||
|
||||
python -m examples.tensorize_vllm_model \
|
||||
--model EleutherAI/gpt-j-6B \
|
||||
--dtype float16 \
|
||||
serialize \
|
||||
--serialized-directory s3://my-bucket/ \
|
||||
--suffix vllm
|
||||
|
||||
Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
|
||||
and saves it to your S3 bucket. A local directory can also be used. This
|
||||
assumes your S3 credentials are specified as environment variables
|
||||
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
|
||||
To provide S3 credentials directly, you can provide `--s3-access-key-id` and
|
||||
`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this
|
||||
script.
|
||||
|
||||
You can also encrypt the model weights with a randomly-generated key by
|
||||
providing a `--keyfile` argument.
|
||||
|
||||
To deserialize a model, you can run something like this from the root
|
||||
level of this repository:
|
||||
|
||||
python -m examples.tensorize_vllm_model \
|
||||
--model EleutherAI/gpt-j-6B \
|
||||
--dtype float16 \
|
||||
deserialize \
|
||||
--path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
|
||||
|
||||
Which downloads the model tensors from your S3 bucket and deserializes them.
|
||||
|
||||
You can also provide a `--keyfile` argument to decrypt the model weights if
|
||||
they were serialized with encryption.
|
||||
|
||||
For more information on the available arguments for serializing, run
|
||||
`python -m examples.tensorize_vllm_model serialize --help`.
|
||||
|
||||
Or for deserializing:
|
||||
|
||||
`python -m examples.tensorize_vllm_model deserialize --help`.
|
||||
|
||||
Once a model is serialized, it can be used to load the model when running the
|
||||
OpenAI inference client at `vllm/entrypoints/openai/api_server.py` by providing
|
||||
the `--tensorizer-uri` CLI argument that is functionally the same as the
|
||||
`--path-to-tensors` argument in this script, along with `--vllm-tensorized`, to
|
||||
signify that the model to be deserialized is a vLLM model, rather than a
|
||||
HuggingFace `PreTrainedModel`, which can also be deserialized using tensorizer
|
||||
in the same inference server, albeit without the speed optimizations. To
|
||||
deserialize an encrypted file, the `--encryption-keyfile` argument can be used
|
||||
to provide the path to the keyfile used to encrypt the model weights. For
|
||||
information on all the arguments that can be used to configure tensorizer's
|
||||
deserialization, check out the tensorizer options argument group in the
|
||||
`vllm/entrypoints/openai/api_server.py` script with `--help`.
|
||||
|
||||
Tensorizer can also be invoked with the `LLM` class directly to load models:
|
||||
|
||||
llm = LLM(model="facebook/opt-125m",
|
||||
load_format="tensorizer",
|
||||
tensorizer_uri=path_to_opt_tensors,
|
||||
num_readers=3,
|
||||
vllm_tensorized=True)
|
||||
"""
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="An example script that can be used to serialize and "
|
||||
"deserialize vLLM models. These models "
|
||||
"can be loaded using tensorizer directly to the GPU "
|
||||
"extremely quickly. Tensor encryption and decryption is "
|
||||
"also supported, although libsodium must be installed to "
|
||||
"use it.")
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
subparsers = parser.add_subparsers(dest='command')
|
||||
|
||||
serialize_parser = subparsers.add_parser(
|
||||
'serialize', help="Serialize a model to `--serialized-directory`")
|
||||
|
||||
serialize_parser.add_argument(
|
||||
"--suffix",
|
||||
type=str,
|
||||
required=False,
|
||||
help=(
|
||||
"The suffix to append to the serialized model directory, which is "
|
||||
"used to construct the location of the serialized model tensors, "
|
||||
"e.g. if `--serialized-directory` is `s3://my-bucket/` and "
|
||||
"`--suffix` is `v1`, the serialized model tensors will be "
|
||||
"saved to "
|
||||
"`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
|
||||
"If none is provided, a random UUID will be used."))
|
||||
serialize_parser.add_argument(
|
||||
"--serialized-directory",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The directory to serialize the model to. "
|
||||
"This can be a local directory or S3 URI. The path to where the "
|
||||
"tensors are saved is a combination of the supplied `dir` and model "
|
||||
"reference ID. For instance, if `dir` is the serialized directory, "
|
||||
"and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
|
||||
"be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
|
||||
"where `suffix` is given by `--suffix` or a random UUID if not "
|
||||
"provided.")
|
||||
|
||||
serialize_parser.add_argument(
|
||||
"--keyfile",
|
||||
type=str,
|
||||
required=False,
|
||||
help=("Encrypt the model weights with a randomly-generated binary key,"
|
||||
" and save the key at this path"))
|
||||
|
||||
deserialize_parser = subparsers.add_parser(
|
||||
'deserialize',
|
||||
help=("Deserialize a model from `--path-to-tensors`"
|
||||
" to verify it can be loaded and used."))
|
||||
|
||||
deserialize_parser.add_argument(
|
||||
"--path-to-tensors",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The local path or S3 URI to the model tensors to deserialize. ")
|
||||
|
||||
deserialize_parser.add_argument(
|
||||
"--keyfile",
|
||||
type=str,
|
||||
required=False,
|
||||
help=("Path to a binary key to use to decrypt the model weights,"
|
||||
" if the model was serialized with encryption"))
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def make_model_contiguous(model):
|
||||
# Ensure tensors are saved in memory contiguously
|
||||
for param in model.parameters():
|
||||
param.data = param.data.contiguous()
|
||||
|
||||
|
||||
def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
model_cls = ModelRegistry.load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
return model_cls
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
||||
|
||||
|
||||
def serialize():
|
||||
|
||||
eng_args_dict = {f.name: getattr(args, f.name) for f in
|
||||
dataclasses.fields(EngineArgs)}
|
||||
engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
|
||||
model = (engine.model_executor.driver_worker.
|
||||
model_runner.model)
|
||||
|
||||
encryption_params = EncryptionParams.random() if keyfile else None
|
||||
if keyfile:
|
||||
with _write_stream(keyfile) as stream:
|
||||
stream.write(encryption_params.key)
|
||||
|
||||
with _write_stream(model_path) as stream:
|
||||
serializer = TensorSerializer(stream, encryption=encryption_params)
|
||||
serializer.write_module(model)
|
||||
serializer.close()
|
||||
|
||||
print("Serialization complete. Model tensors saved to", model_path)
|
||||
if keyfile:
|
||||
print("Key saved to", keyfile)
|
||||
|
||||
|
||||
def deserialize():
|
||||
config = AutoConfig.from_pretrained(model_ref)
|
||||
|
||||
with no_init_or_tensor():
|
||||
model_class = _get_vllm_model_architecture(config)
|
||||
model = model_class(config)
|
||||
|
||||
before_mem = get_mem_usage()
|
||||
start = time.time()
|
||||
|
||||
if keyfile:
|
||||
with _read_stream(keyfile) as stream:
|
||||
key = stream.read()
|
||||
decryption_params = DecryptionParams.from_key(key)
|
||||
tensorizer_args.deserializer_params['encryption'] = \
|
||||
decryption_params
|
||||
|
||||
with (_read_stream(model_path)) as stream, TensorDeserializer(
|
||||
stream, **tensorizer_args.deserializer_params) as deserializer:
|
||||
deserializer.load_into_module(model)
|
||||
end = time.time()
|
||||
|
||||
# Brag about how fast we are.
|
||||
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
|
||||
duration = end - start
|
||||
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
|
||||
after_mem = get_mem_usage()
|
||||
print(
|
||||
f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
|
||||
)
|
||||
print(f"Memory usage before: {before_mem}")
|
||||
print(f"Memory usage after: {after_mem}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
args = parse_args()
|
||||
|
||||
s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
|
||||
or None)
|
||||
s3_secret_access_key = (args.s3_secret_access_key
|
||||
or os.environ.get("S3_SECRET_ACCESS_KEY") or None)
|
||||
|
||||
s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)
|
||||
|
||||
_read_stream, _write_stream = (partial(
|
||||
stream_io.open_stream,
|
||||
mode=mode,
|
||||
s3_access_key_id=s3_access_key_id,
|
||||
s3_secret_access_key=s3_secret_access_key,
|
||||
s3_endpoint=s3_endpoint,
|
||||
) for mode in ("rb", "wb+"))
|
||||
|
||||
model_ref = args.model
|
||||
|
||||
model_name = model_ref.split("/")[1]
|
||||
|
||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||
os.environ["MASTER_PORT"] = "8080"
|
||||
|
||||
torch.distributed.init_process_group(world_size=1, rank=0)
|
||||
initialize_model_parallel()
|
||||
|
||||
keyfile = args.keyfile if args.keyfile else None
|
||||
|
||||
if args.command == "serialize":
|
||||
input_dir = args.serialized_directory.rstrip('/')
|
||||
suffix = args.suffix if args.suffix else uuid.uuid4().hex
|
||||
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
|
||||
model_path = f"{base_path}/model.tensors"
|
||||
serialize()
|
||||
elif args.command == "deserialize":
|
||||
tensorizer_args = TensorizerArgs.from_cli_args(args)
|
||||
model_path = args.path_to_tensors
|
||||
deserialize()
|
||||
else:
|
||||
raise ValueError("Either serialize or deserialize must be specified.")
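The serialize() and deserialize() helpers above wrap the tensorizer API around a full vLLM engine. As a hedged, minimal sketch of just the tensorizer round-trip on a plain torch module (file name and module shape are illustrative, plain local file objects are assumed to be accepted in place of stream_io streams, and S3 and encryption are skipped):

    import torch.nn as nn
    from tensorizer import TensorDeserializer, TensorSerializer
    from tensorizer.utils import no_init_or_tensor

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

    # Serialize every parameter and buffer of the module to a local file.
    with open("tiny.tensors", "wb") as stream:
        serializer = TensorSerializer(stream)
        serializer.write_module(model)
        serializer.close()

    # Rebuild the module skeleton without materializing weights, then load.
    with no_init_or_tensor():
        restored = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
    with open("tiny.tensors", "rb") as stream, TensorDeserializer(
            stream, device="cpu") as deserializer:
        deserializer.load_into_module(restored)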
|
format.sh (22 lines changed)
@ -93,9 +93,23 @@ fi
|
||||
echo 'vLLM yapf: Done'
|
||||
|
||||
# Run mypy
|
||||
# TODO(zhuohan): Enable mypy
|
||||
# echo 'vLLM mypy:'
|
||||
# mypy
|
||||
echo 'vLLM mypy:'
|
||||
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
|
||||
# TODO(sang): Follow up
|
||||
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
|
||||
|
||||
CODESPELL_EXCLUDES=(
|
||||
'--skip' '*docs/source/_build/**'
|
||||
@ -228,5 +242,3 @@ if ! git diff --quiet &>/dev/null; then
|
||||
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
|
pyproject.toml
@ -5,7 +5,7 @@ requires = [
|
||||
"ninja",
|
||||
"packaging",
|
||||
"setuptools >= 49.4.0",
|
||||
"torch == 2.1.2",
|
||||
"torch == 2.2.1",
|
||||
"wheel",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
@ -13,6 +13,10 @@ build-backend = "setuptools.build_meta"
|
||||
[tool.ruff]
|
||||
# Allow lines to be as long as 80.
|
||||
line-length = 80
|
||||
exclude = [
|
||||
# External file, leaving license intact
|
||||
"examples/fp8/quantizer/quantize.py"
|
||||
]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
@ -42,10 +46,13 @@ ignore = [
|
||||
python_version = "3.8"
|
||||
|
||||
ignore_missing_imports = true
|
||||
check_untyped_defs = true
|
||||
|
||||
files = "vllm"
|
||||
# TODO(woosuk): Include the code from Megatron and HuggingFace.
|
||||
exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/"
|
||||
exclude = [
|
||||
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
|
||||
]
|
||||
|
||||
|
||||
[tool.codespell]
|
||||
|
@ -3,5 +3,5 @@ cmake>=3.21
|
||||
ninja
|
||||
packaging
|
||||
setuptools>=49.4.0
|
||||
torch==2.1.2
|
||||
torch==2.2.1
|
||||
wheel
|
||||
|
@ -1,19 +1,17 @@
|
||||
cmake>=3.21
|
||||
cmake >= 3.21
|
||||
ninja # For faster builds.
|
||||
psutil
|
||||
ray >= 2.9
|
||||
sentencepiece # Required for LLaMA tokenizer.
|
||||
numpy
|
||||
torch == 2.1.2
|
||||
requests
|
||||
py-cpuinfo
|
||||
transformers >= 4.39.1 # Required for StarCoder2 & Llava.
|
||||
xformers == 0.0.23.post1 # Required for CUDA 12.1.
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
pydantic >= 2.0 # Required for OpenAI server.
|
||||
prometheus_client >= 0.18.0
|
||||
pynvml == 11.5.0
|
||||
triton >= 2.1.0
|
||||
outlines == 0.0.34
|
||||
tiktoken == 0.6.0 # Required for DBRX tokenizer
|
||||
lm-format-enforcer == 0.9.3
|
||||
outlines == 0.0.34 # Requires torch >= 2.1.0
|
||||
typing_extensions
|
||||
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
|
requirements-cpu.txt
@ -1,15 +1,6 @@
|
||||
cmake>=3.21
|
||||
ninja # For faster builds.
|
||||
psutil
|
||||
ray >= 2.9
|
||||
sentencepiece # Required for LLaMA tokenizer.
|
||||
numpy
|
||||
transformers >= 4.38.0 # Required for Gemma.
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
pydantic >= 2.0 # Required for OpenAI server.
|
||||
prometheus_client >= 0.18.0
|
||||
torch == 2.1.2+cpu
|
||||
triton >= 2.1.0
|
||||
filelock == 3.13.3
|
||||
py-cpuinfo
|
||||
# Common dependencies
|
||||
-r requirements-common.txt
|
||||
|
||||
# Dependencies for x86_64 CPUs
|
||||
torch == 2.2.1+cpu
|
||||
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
|
requirements-cuda.txt (new file, 9 lines)
@ -0,0 +1,9 @@
# Common dependencies
-r requirements-common.txt

# Dependencies for NVIDIA GPUs
ray >= 2.9
pynvml == 11.5.0
vllm-nccl-cu12>=2.18,<2.19  # for downloading nccl library
torch == 2.2.1
xformers == 0.0.25  # Requires PyTorch 2.2.1
requirements-dev.txt
@ -7,13 +7,14 @@ codespell==2.2.6
|
||||
isort==5.13.2
|
||||
|
||||
# type checking
|
||||
mypy==0.991
|
||||
mypy==1.9.0
|
||||
types-PyYAML
|
||||
types-requests
|
||||
types-setuptools
|
||||
|
||||
# testing
|
||||
pytest
|
||||
tensorizer==2.9.0a0
|
||||
pytest-forked
|
||||
pytest-asyncio
|
||||
pytest-rerunfailures
|
||||
|
requirements-neuron.txt
@ -1,12 +1,7 @@
|
||||
sentencepiece # Required for LLaMA tokenizer.
|
||||
numpy
|
||||
# Common dependencies
|
||||
-r requirements-common.txt
|
||||
|
||||
# Dependencies for Neuron devices
|
||||
transformers-neuronx >= 0.9.0
|
||||
torch-neuronx >= 2.1.0
|
||||
neuronx-cc
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
pydantic >= 2.0 # Required for OpenAI server.
|
||||
prometheus_client >= 0.18.0
|
||||
requests
|
||||
psutil
|
||||
py-cpuinfo
|
requirements-rocm.txt
@ -1,18 +1,5 @@
|
||||
cmake>=3.21
|
||||
ninja # For faster builds.
|
||||
typing-extensions>=4.8.0
|
||||
starlette
|
||||
requests
|
||||
py-cpuinfo
|
||||
psutil
|
||||
# Common dependencies
|
||||
-r requirements-common.txt
|
||||
|
||||
# Dependencies for AMD GPUs
|
||||
ray == 2.9.3
|
||||
sentencepiece # Required for LLaMA tokenizer.
|
||||
numpy
|
||||
tokenizers>=0.15.0
|
||||
transformers >= 4.39.1 # Required for StarCoder2 & Llava.
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
pydantic >= 2.0 # Required for OpenAI server.
|
||||
prometheus_client >= 0.18.0
|
||||
outlines == 0.0.34
|
||||
tiktoken == 0.6.0 # Required for DBRX tokenizer
|
||||
|
requirements-tpu.txt (new file, 6 lines)
@ -0,0 +1,6 @@
# Common dependencies
-r requirements-common.txt

torch
jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
flax >= 0.8
setup.py (66 lines changed)
@ -5,7 +5,7 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
from shutil import which
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from packaging.version import Version, parse
|
||||
@ -52,7 +52,7 @@ class CMakeExtension(Extension):
|
||||
|
||||
class cmake_build_ext(build_ext):
|
||||
# A dict of extension directories that have been configured.
|
||||
did_config = {}
|
||||
did_config: Dict[str, bool] = {}
|
||||
|
||||
#
|
||||
# Determine number of compilation jobs and optionally nvcc compile threads.
|
||||
@ -188,9 +188,9 @@ class cmake_build_ext(build_ext):
|
||||
|
||||
|
||||
def _is_cuda() -> bool:
|
||||
return VLLM_TARGET_DEVICE == "cuda" \
|
||||
and torch.version.cuda is not None \
|
||||
and not _is_neuron()
|
||||
has_cuda = torch.version.cuda is not None
|
||||
return (VLLM_TARGET_DEVICE == "cuda" and has_cuda
|
||||
and not (_is_neuron() or _is_tpu()))
|
||||
|
||||
|
||||
def _is_hip() -> bool:
|
||||
@ -207,10 +207,18 @@ def _is_neuron() -> bool:
|
||||
return torch_neuronx_installed
|
||||
|
||||
|
||||
def _is_tpu() -> bool:
|
||||
return True # FIXME
|
||||
|
||||
|
||||
def _is_cpu() -> bool:
|
||||
return VLLM_TARGET_DEVICE == "cpu"
|
||||
|
||||
|
||||
def _build_custom_ops() -> bool:
|
||||
return _is_cuda() or _is_hip() or _is_cpu()
|
||||
|
||||
|
||||
def _install_punica() -> bool:
|
||||
return bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0")))
|
||||
|
||||
@ -261,6 +269,7 @@ def get_nvcc_cuda_version() -> Version:
|
||||
|
||||
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
|
||||
"""
|
||||
assert CUDA_HOME is not None, "CUDA_HOME is not set"
|
||||
nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"],
|
||||
universal_newlines=True)
|
||||
output = nvcc_output.split()
|
||||
@ -306,6 +315,8 @@ def get_vllm_version() -> str:
|
||||
if neuron_version != MAIN_CUDA_VERSION:
|
||||
neuron_version_str = neuron_version.replace(".", "")[:3]
|
||||
version += f"+neuron{neuron_version_str}"
|
||||
elif _is_tpu():
|
||||
version += "+tpu"
|
||||
elif _is_cpu():
|
||||
version += "+cpu"
|
||||
else:
|
||||
@ -325,22 +336,40 @@ def read_readme() -> str:
|
||||
|
||||
def get_requirements() -> List[str]:
|
||||
"""Get Python package dependencies from requirements.txt."""
|
||||
|
||||
def _read_requirements(filename: str) -> List[str]:
|
||||
with open(get_path(filename)) as f:
|
||||
requirements = f.read().strip().split("\n")
|
||||
resolved_requirements = []
|
||||
for line in requirements:
|
||||
if line.startswith("-r "):
|
||||
resolved_requirements += _read_requirements(line.split()[1])
|
||||
else:
|
||||
resolved_requirements.append(line)
|
||||
return resolved_requirements
|
||||
|
||||
if _is_cuda():
|
||||
with open(get_path("requirements.txt")) as f:
|
||||
requirements = f.read().strip().split("\n")
|
||||
requirements = _read_requirements("requirements-cuda.txt")
|
||||
cuda_major = torch.version.cuda.split(".")[0]
|
||||
modified_requirements = []
|
||||
for req in requirements:
|
||||
if "vllm-nccl-cu12" in req:
|
||||
modified_requirements.append(
|
||||
req.replace("vllm-nccl-cu12", f"vllm-nccl-cu{cuda_major}"))
|
||||
else:
|
||||
modified_requirements.append(req)
|
||||
requirements = modified_requirements
|
||||
elif _is_hip():
|
||||
with open(get_path("requirements-rocm.txt")) as f:
|
||||
requirements = f.read().strip().split("\n")
|
||||
requirements = _read_requirements("requirements-rocm.txt")
|
||||
elif _is_neuron():
|
||||
with open(get_path("requirements-neuron.txt")) as f:
|
||||
requirements = f.read().strip().split("\n")
|
||||
requirements = _read_requirements("requirements-neuron.txt")
|
||||
elif _is_tpu():
|
||||
requirements = _read_requirements("requirements-tpu.txt")
|
||||
elif _is_cpu():
|
||||
with open(get_path("requirements-cpu.txt")) as f:
|
||||
requirements = f.read().strip().split("\n")
|
||||
requirements = _read_requirements("requirements-cpu.txt")
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported platform, please use CUDA, ROCM or Neuron.")
|
||||
|
||||
"Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.")
|
||||
return requirements
|
||||
|
||||
|
||||
@ -352,7 +381,7 @@ if _is_cuda():
|
||||
if _install_punica():
|
||||
ext_modules.append(CMakeExtension(name="vllm._punica_C"))
|
||||
|
||||
if not _is_neuron():
|
||||
if _build_custom_ops():
|
||||
ext_modules.append(CMakeExtension(name="vllm._C"))
|
||||
|
||||
package_data = {
|
||||
@ -388,6 +417,9 @@ setup(
|
||||
python_requires=">=3.8",
|
||||
install_requires=get_requirements(),
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
|
||||
extras_require={
|
||||
"tensorizer": ["tensorizer==2.9.0a1"],
|
||||
},
|
||||
cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
|
||||
package_data=package_data,
|
||||
)
|
||||
|
@ -25,21 +25,30 @@ def _query_server_long(prompt: str) -> dict:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_server(tokenizer_pool_size: int):
|
||||
def api_server(tokenizer_pool_size: int, engine_use_ray: bool,
|
||||
worker_use_ray: bool):
|
||||
script_path = Path(__file__).parent.joinpath(
|
||||
"api_server_async_engine.py").absolute()
|
||||
uvicorn_process = subprocess.Popen([
|
||||
commands = [
|
||||
sys.executable, "-u",
|
||||
str(script_path), "--model", "facebook/opt-125m", "--host",
|
||||
"127.0.0.1", "--tokenizer-pool-size",
|
||||
str(tokenizer_pool_size)
|
||||
])
|
||||
]
|
||||
if engine_use_ray:
|
||||
commands.append("--engine-use-ray")
|
||||
if worker_use_ray:
|
||||
commands.append("--worker-use-ray")
|
||||
uvicorn_process = subprocess.Popen(commands)
|
||||
yield
|
||||
uvicorn_process.terminate()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
|
||||
def test_api_server(api_server, tokenizer_pool_size: int):
|
||||
@pytest.mark.parametrize("worker_use_ray", [False, True])
|
||||
@pytest.mark.parametrize("engine_use_ray", [False, True])
|
||||
def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool,
|
||||
engine_use_ray: bool):
|
||||
"""
|
||||
Run the API server and test it.
|
||||
|
||||
|
tests/basic_correctness/test_chunked_prefill.py (new file, 66 lines)
@ -0,0 +1,66 @@
|
||||
"""Compare the outputs of HF and vLLM when using greedy sampling.
|
||||
|
||||
It tests chunked prefill. Chunked prefill can be enabled by
|
||||
enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens,
|
||||
prefill requests are chunked.
|
||||
|
||||
Run `pytest tests/basic_correctness/test_chunked_prefill.py`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
MODELS = [
|
||||
"facebook/opt-125m",
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
|
||||
@pytest.mark.parametrize("enforce_eager", [False, True])
|
||||
# NOTE: Increasing this in this suite will fail CI because we currently cannot
|
||||
# reset distributed env properly. Use a value > 1 only when testing locally.
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
chunked_prefill_token_size: int,
|
||||
enforce_eager: bool,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
max_num_seqs = min(chunked_prefill_token_size, 256)
|
||||
enable_chunked_prefill = False
|
||||
max_num_batched_tokens = None
|
||||
if chunked_prefill_token_size != -1:
|
||||
enable_chunked_prefill = True
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
enforce_eager=enforce_eager,
|
||||
max_num_seqs=max_num_seqs,
|
||||
)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
print(vllm_outputs[0])
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
tests/conftest.py
@ -11,8 +11,7 @@ from transformers import (AutoModelForCausalLM, AutoProcessor,
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
destroy_model_parallel)
|
||||
from vllm.distributed import destroy_model_parallel
|
||||
from vllm.sequence import MultiModalData
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
@ -56,11 +55,15 @@ def cleanup():
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def should_do_global_cleanup_after_test() -> bool:
|
||||
def should_do_global_cleanup_after_test(request) -> bool:
|
||||
"""Allow subdirectories to skip global cleanup by overriding this fixture.
|
||||
This can provide a ~10x speedup for non-GPU unit tests since they don't need
|
||||
to initialize torch.
|
||||
"""
|
||||
|
||||
if request.node.get_closest_marker("skip_global_cleanup"):
|
||||
return False
|
||||
|
||||
return True
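    # Hedged illustration (not part of the diff): a test opts out of the global
    # cleanup by carrying the marker this fixture checks. The test name and body
    # are illustrative; the marker is assumed to be registered by the suite.
    import pytest

    @pytest.mark.skip_global_cleanup
    def test_cpu_only_helper():
        # Skips the expensive torch/GPU teardown after this test finishes.
        assert 1 + 1 == 2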
|
||||
|
||||
|
||||
@ -398,7 +401,7 @@ class VllmRunner:
|
||||
cleanup()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope="session")
|
||||
def vllm_runner():
|
||||
return VllmRunner
|
||||
|
||||
|
@ -16,7 +16,7 @@ from vllm import SamplingParams
|
||||
|
||||
# Allow only 5 sequences of ~1024 tokens in worst case.
|
||||
"block_size": 16,
|
||||
"forced_num_gpu_blocks": 5 * (64 + 1),
|
||||
"num_gpu_blocks_override": 5 * (64 + 1),
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{
|
||||
@ -162,14 +162,14 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
|
||||
|
||||
# Allow only 2 sequences of ~128 tokens in worst case.
|
||||
# Note 8 = 128/block_size
|
||||
"forced_num_gpu_blocks": 2 * (8 + 1),
|
||||
"num_gpu_blocks_override": 2 * (8 + 1),
|
||||
},
|
||||
{
|
||||
"block_size": 8,
|
||||
|
||||
# Allow only 2 sequences of ~128 tokens in worst case.
|
||||
# Note 16 = 128/block_size
|
||||
"forced_num_gpu_blocks": 2 * (16 + 1),
|
||||
"num_gpu_blocks_override": 2 * (16 + 1),
|
||||
}
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{
|
||||
|
tests/core/test_chunked_prefill_scheduler.py (new file, 563 lines)
@ -0,0 +1,563 @@
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest # noqa
|
||||
|
||||
from vllm.config import CacheConfig, SchedulerConfig
|
||||
from vllm.core.scheduler import Scheduler
|
||||
from vllm.sequence import Logprob, SequenceGroup
|
||||
|
||||
from .utils import create_dummy_prompt
|
||||
|
||||
|
||||
def get_sequence_groups(scheduler_output):
|
||||
return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
|
||||
|
||||
|
||||
def append_new_token(seq_group, token_id: int):
|
||||
for seq in seq_group.get_seqs():
|
||||
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
|
||||
|
||||
|
||||
def schedule_and_update_computed_tokens(scheduler):
|
||||
metas, out = scheduler.schedule()
|
||||
for s, meta in zip(out.scheduled_seq_groups, metas):
|
||||
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
|
||||
return metas, out
|
||||
|
||||
|
||||
def test_simple():
|
||||
"""Verify basic scheduling works."""
|
||||
block_size = 4
|
||||
num_seq_group = 4
|
||||
max_model_len = 16
|
||||
max_num_batched_tokens = 64
|
||||
scheduler_config = SchedulerConfig(max_num_batched_tokens,
|
||||
num_seq_group,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||
running: List[SequenceGroup] = []
|
||||
|
||||
# Add seq groups to scheduler.
|
||||
for i in range(num_seq_group):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
running.append(seq_group)
|
||||
|
||||
# Schedule seq groups prompts.
|
||||
num_tokens = block_size * num_seq_group
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert set(get_sequence_groups(out)) == set(running)
|
||||
assert out.num_batched_tokens == num_tokens
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
and not out.blocks_to_swap_out)
|
||||
assert len(seq_group_meta) == num_seq_group
|
||||
for s in running:
|
||||
append_new_token(s, 1)
|
||||
|
||||
# Schedule seq groups generation.
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert set(get_sequence_groups(out)) == set(running)
|
||||
assert out.num_batched_tokens == num_seq_group
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
and not out.blocks_to_swap_out)
|
||||
assert len(seq_group_meta) == num_seq_group
|
||||
|
||||
|
||||
def test_chunk():
|
||||
"""Verify prefills are chunked properly."""
|
||||
block_size = 4
|
||||
max_seqs = 60
|
||||
max_model_len = 80
|
||||
max_num_batched_tokens = 64
|
||||
scheduler_config = SchedulerConfig(max_num_batched_tokens,
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||
running: List[SequenceGroup] = []
|
||||
|
||||
# Add seq groups to scheduler.
|
||||
for i in range(2):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
running.append(seq_group)
|
||||
|
||||
# Verify the second request is chunked.
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert set(get_sequence_groups(out)) == set(running)
|
||||
assert seq_group_meta[0].token_chunk_size == 60
|
||||
# Verify it is chunked.
|
||||
assert seq_group_meta[1].token_chunk_size == 4
|
||||
assert out.num_prefill_groups == 2
|
||||
assert out.num_batched_tokens == 64
|
||||
# Only the first seq group has a new token appended.
|
||||
append_new_token(running[0], 1)
|
||||
|
||||
# One chunked prefill, and one decoding.
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert set(get_sequence_groups(out)) == set(running)
|
||||
# The first one is prefill. Scheduler guarantees ordering.
|
||||
assert seq_group_meta[0].token_chunk_size == 56
|
||||
# The second one is a chunked prefill.
|
||||
assert seq_group_meta[1].token_chunk_size == 1
|
||||
assert out.num_prefill_groups == 1
|
||||
assert out.num_batched_tokens == 57
|
||||
|
||||
|
||||
def test_complex():
|
||||
block_size = 4
|
||||
max_seqs = 60
|
||||
max_model_len = 80
|
||||
max_num_batched_tokens = 64
|
||||
scheduler_config = SchedulerConfig(max_num_batched_tokens,
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||
running: List[SequenceGroup] = []
|
||||
|
||||
# Add seq groups to scheduler.
|
||||
for i in range(2):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
running.append(seq_group)
|
||||
assert seq_group.is_prefill()
|
||||
|
||||
# Verify the second request is chunked.
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
|
||||
assert set(get_sequence_groups(out)) == set(running)
|
||||
assert seq_group_meta[0].token_chunk_size == 60
|
||||
# Verify it is chunked.
|
||||
assert seq_group_meta[1].token_chunk_size == 4
|
||||
assert not running[0].is_prefill()
|
||||
assert running[1].is_prefill()
|
||||
assert out.num_prefill_groups == 2
|
||||
assert out.num_batched_tokens == 64
|
||||
# Only the first seq group has a new token appended.
|
||||
append_new_token(running[0], 1)
|
||||
|
||||
# Add 2 more requests.
|
||||
for i in range(2, 4):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
running.append(seq_group)
|
||||
|
||||
# Decoding & chunked prefill & first chunk of 3rd request is scheduled.
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(get_sequence_groups(out)) == 3
|
||||
# The first one is the first chunked prefill.
|
||||
assert seq_group_meta[0].token_chunk_size == 7
|
||||
# The second one is the second new chunked prefill.
|
||||
assert seq_group_meta[1].token_chunk_size == 56
|
||||
# The last one is decode.
|
||||
assert seq_group_meta[2].token_chunk_size == 1
|
||||
# Two of them are in chunked prefill.
|
||||
assert out.num_prefill_groups == 2
|
||||
assert out.num_batched_tokens == 64
|
||||
# The first 2 requests are now in the decoding phase.
|
||||
append_new_token(running[0], 1)
|
||||
assert not running[0].is_prefill()
|
||||
append_new_token(running[1], 1)
|
||||
assert not running[1].is_prefill()
|
||||
# The third request is still in prefill stage.
|
||||
assert running[2].is_prefill()
|
||||
|
||||
|
||||
def test_maximal_decoding():
|
||||
"""Verify decoding requests are prioritized."""
|
||||
block_size = 4
|
||||
max_seqs = 2
|
||||
max_model_len = 2
|
||||
max_num_batched_tokens = 2
|
||||
scheduler_config = SchedulerConfig(max_num_batched_tokens,
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||
running: List[SequenceGroup] = []
|
||||
|
||||
# Add seq groups to scheduler.
|
||||
for i in range(2):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=2)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
running.append(seq_group)
|
||||
assert seq_group.is_prefill()
|
||||
|
||||
# The first prefill is scheduled.
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(get_sequence_groups(out)) == 1
|
||||
assert seq_group_meta[0].token_chunk_size == 2
|
||||
assert not running[0].is_prefill()
|
||||
assert running[1].is_prefill()
|
||||
assert out.num_prefill_groups == 1
|
||||
assert out.num_batched_tokens == 2
|
||||
# Only the first seq group has a new token appended.
|
||||
append_new_token(running[0], 1)
|
||||
|
||||
# Create one more seq_group.
|
||||
_, seq_group = create_dummy_prompt("3", prompt_length=2)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
running.append(seq_group)
|
||||
assert seq_group.is_prefill()
|
||||
# The first decoding + second chunk is scheduled.
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(get_sequence_groups(out)) == 2
|
||||
assert seq_group_meta[0].token_chunk_size == 1
|
||||
assert seq_group_meta[1].token_chunk_size == 1
|
||||
assert not running[0].is_prefill()
|
||||
assert running[1].is_prefill()
|
||||
assert running[2].is_prefill()
|
||||
assert out.num_prefill_groups == 1
|
||||
assert out.num_batched_tokens == 2
|
||||
append_new_token(running[0], 1)
|
||||
|
||||
# Decoding + running prefill is prioritized.
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(get_sequence_groups(out)) == 2
|
||||
assert seq_group_meta[0].token_chunk_size == 1
|
||||
assert seq_group_meta[1].token_chunk_size == 1
|
||||
assert not running[0].is_prefill()
|
||||
assert not running[1].is_prefill()
|
||||
assert out.num_prefill_groups == 1
|
||||
assert out.num_batched_tokens == 2
|
||||
append_new_token(running[0], 1)
|
||||
append_new_token(running[1], 1)
|
||||
|
||||
# Only decoding is prioritized.
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(get_sequence_groups(out)) == 2
|
||||
assert seq_group_meta[0].token_chunk_size == 1
|
||||
assert seq_group_meta[1].token_chunk_size == 1
|
||||
assert not running[0].is_prefill()
|
||||
assert not running[1].is_prefill()
|
||||
assert out.num_prefill_groups == 0
|
||||
assert out.num_batched_tokens == 2
|
||||
append_new_token(running[0], 1)
|
||||
append_new_token(running[1], 1)
|
||||
|
||||
# After aborting the decoding request, the fcfs new prefill is prioritized.
|
||||
scheduler.abort_seq_group(running[0].request_id)
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(get_sequence_groups(out)) == 2
|
||||
assert seq_group_meta[0].token_chunk_size == 1
|
||||
assert seq_group_meta[1].token_chunk_size == 1
|
||||
assert not running[1].is_prefill()
|
||||
assert running[2].is_prefill()
|
||||
assert out.num_prefill_groups == 1
|
||||
assert out.num_batched_tokens == 2
|
||||
|
||||
|
||||
def test_prompt_limit():
|
||||
"""Verify max_num_batched_tokens < max_model_len is possible."""
|
||||
block_size = 4
|
||||
max_seqs = 32
|
||||
max_model_len = 64
|
||||
max_num_batched_tokens = 32
|
||||
scheduler_config = SchedulerConfig(max_num_batched_tokens,
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||
running: List[SequenceGroup] = []
|
||||
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=48)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
running.append(seq_group)
|
||||
assert seq_group.is_prefill()
|
||||
|
||||
# A prompt longer than max_num_batched_tokens should still be scheduled.
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(get_sequence_groups(out)) == 1
|
||||
assert seq_group_meta[0].token_chunk_size == 32
|
||||
assert running[0].is_prefill()
|
||||
assert out.num_prefill_groups == 1
|
||||
assert out.num_batched_tokens == 32
|
||||
|
||||
|
||||
def test_prompt_limit_exceed():
|
||||
block_size = 4
|
||||
max_seqs = 64
|
||||
max_model_len = 32
|
||||
max_num_batched_tokens = 64
|
||||
scheduler_config = SchedulerConfig(max_num_batched_tokens,
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||
running: List[SequenceGroup] = []
|
||||
|
||||
_, seq_group = create_dummy_prompt("2", prompt_length=48)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
running.append(seq_group)
|
||||
assert seq_group.is_prefill()
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.ignored_seq_groups) == 1
|
||||
assert out.ignored_seq_groups[0] == seq_group
|
||||
|
||||
|
||||
def test_swap():
|
||||
"""Verify swapping works with chunked prefill requests"""
|
||||
block_size = 4
|
||||
max_seqs = 30
|
||||
max_model_len = 200
|
||||
max_num_batched_tokens = 30
|
||||
scheduler_config = SchedulerConfig(max_num_batched_tokens,
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
# The request is chunked.
|
||||
# prefill scheduled now.
|
||||
assert len(out.scheduled_seq_groups) == 1
|
||||
assert out.num_prefill_groups == 1
|
||||
assert seq_group.is_prefill()
|
||||
assert out.num_batched_tokens == max_num_batched_tokens
|
||||
|
||||
# The last request should be swapped out.
|
||||
scheduler.block_manager.can_append_slots = MagicMock()
|
||||
|
||||
def cannot_append_second_group(seq_group, num_lookahead_slots):
|
||||
return seq_group.request_id != "1"
|
||||
|
||||
scheduler.block_manager.can_append_slots.side_effect = (
|
||||
cannot_append_second_group)
|
||||
|
||||
# The running prefill is now swapped.
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) == 0
|
||||
assert out.num_batched_tokens == 0
|
||||
assert out.blocks_to_swap_out != {}
|
||||
assert out.blocks_to_swap_in == {}
|
||||
|
||||
# Add 1 more task. Swap should be prioritized over new prefill.
|
||||
_, seq_group = create_dummy_prompt("2", prompt_length=60)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) == 1
|
||||
# 3 decodes. It is swapped in.
|
||||
assert out.num_batched_tokens == 30
|
||||
assert out.blocks_to_swap_in != {}
|
||||
assert out.blocks_to_swap_out == {}
|
||||
|
||||
|
||||
def test_running_prefill_prioritized_over_swap():
|
||||
block_size = 4
|
||||
max_seqs = 30
|
||||
max_model_len = 200
|
||||
max_num_batched_tokens = 30
|
||||
scheduler_config = SchedulerConfig(max_num_batched_tokens,
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
# The request is chunked.
|
||||
# prefill scheduled now.
|
||||
assert len(out.scheduled_seq_groups) == 1
|
||||
assert out.num_prefill_groups == 1
|
||||
assert seq_group.is_prefill()
|
||||
assert out.num_batched_tokens == max_num_batched_tokens
|
||||
|
||||
# The request should be swapped out.
|
||||
scheduler.block_manager.can_append_slots = MagicMock()
|
||||
|
||||
def cannot_append_second_group(seq_group, num_lookahead_slots):
|
||||
return seq_group.request_id != "1"
|
||||
|
||||
scheduler.block_manager.can_append_slots.side_effect = (
|
||||
cannot_append_second_group)
|
||||
|
||||
# The running prefill is now swapped.
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) == 0
|
||||
assert out.num_batched_tokens == 0
|
||||
assert out.blocks_to_swap_out != {}
|
||||
assert out.blocks_to_swap_in == {}
|
||||
|
||||
# Add 1 more task. Swap is not possible, so prefill is running.
|
||||
scheduler.block_manager.can_swap_in = MagicMock()
|
||||
scheduler.block_manager.can_swap_in.return_value = False
|
||||
|
||||
_, seq_group2 = create_dummy_prompt("2", prompt_length=60)
|
||||
scheduler.add_seq_group(seq_group2)
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) == 1
|
||||
# 3 decodes. It is swapped in.
|
||||
assert out.num_batched_tokens == 30
|
||||
assert out.blocks_to_swap_in == {}
|
||||
assert out.blocks_to_swap_out == {}
|
||||
assert out.scheduled_seq_groups[0].seq_group == seq_group2
|
||||
|
||||
# Now although swap is possible, running prefill is prioritized.
|
||||
scheduler.block_manager.can_swap_in.return_value = True
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) == 1
|
||||
# 3 decodes. It is swapped in.
|
||||
assert out.num_batched_tokens == 30
|
||||
assert out.blocks_to_swap_in == {}
|
||||
assert out.blocks_to_swap_out == {}
|
||||
assert not seq_group2.is_prefill()
|
||||
assert out.scheduled_seq_groups[0].seq_group == seq_group2
|
||||
append_new_token(seq_group2, 1)
|
||||
|
||||
# Decoding is prioritized.
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) == 1
|
||||
# 3 decodes. It is swapped in.
|
||||
assert out.num_batched_tokens == 1
|
||||
assert out.blocks_to_swap_in == {}
|
||||
assert out.blocks_to_swap_out == {}
|
||||
assert not seq_group2.is_prefill()
|
||||
assert out.scheduled_seq_groups[0].seq_group == seq_group2
|
||||
append_new_token(seq_group2, 1)
|
||||
|
||||
# Since we abort the sequence group, we can finally swap.
|
||||
scheduler.abort_seq_group(seq_group2.request_id)
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) == 1
|
||||
assert out.num_batched_tokens == 30
|
||||
assert out.blocks_to_swap_in != {}
|
||||
assert out.blocks_to_swap_out == {}
|
||||
|
||||
|
||||
def test_chunked_prefill_preempt():
|
||||
"""Verify preempt works with chunked prefill requests"""
|
||||
block_size = 4
|
||||
max_seqs = 30
|
||||
max_model_len = 200
|
||||
max_num_batched_tokens = 30
|
||||
scheduler_config = SchedulerConfig(max_num_batched_tokens,
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
# The request is chunked.
|
||||
# prefill scheduled now.
|
||||
assert len(out.scheduled_seq_groups) == 1
|
||||
assert out.num_prefill_groups == 1
|
||||
assert seq_group.is_prefill()
|
||||
assert out.num_batched_tokens == max_num_batched_tokens
|
||||
|
||||
# The request should be preempted.
|
||||
scheduler.block_manager.can_append_slots = MagicMock()
|
||||
|
||||
def cannot_append_second_group(seq_group, num_lookahead_slots):
|
||||
return seq_group.request_id != "1"
|
||||
|
||||
scheduler.block_manager.can_append_slots.side_effect = (
|
||||
cannot_append_second_group)
|
||||
|
||||
# The running prefill is now preempted.
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) == 0
|
||||
assert out.num_batched_tokens == 0
|
||||
assert out.blocks_to_swap_out == {}
|
||||
assert out.blocks_to_swap_in == {}
|
||||
|
||||
# Make sure we can reschedule preempted request.
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) == 1
|
||||
assert out.num_prefill_groups == 1
|
||||
assert seq_group.is_prefill()
|
||||
assert out.num_batched_tokens == max_num_batched_tokens
|
||||
assert seq_group.get_num_uncomputed_tokens() == 30
|
||||
|
||||
# We should be able to run prefill twice as it is chunked.
|
||||
def cannot_append_second_group(seq_group, num_lookahead_slots):
|
||||
return True
|
||||
|
||||
scheduler.block_manager.can_append_slots.side_effect = (
|
||||
cannot_append_second_group)
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) == 1
|
||||
assert out.num_prefill_groups == 1
|
||||
assert not seq_group.is_prefill()
|
||||
assert out.num_batched_tokens == max_num_batched_tokens
|
||||
|
||||
|
||||
def test_chunked_prefill_max_seqs():
|
||||
block_size = 4
|
||||
max_seqs = 2
|
||||
max_model_len = 80
|
||||
max_num_batched_tokens = 64
|
||||
scheduler_config = SchedulerConfig(max_num_batched_tokens,
|
||||
max_seqs,
|
||||
max_model_len,
|
||||
enable_chunked_prefill=True)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||
running = []
|
||||
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=65)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
running.append(seq_group)
|
||||
# The first prefill is chunked.
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert seq_group_meta[0].token_chunk_size == max_num_batched_tokens
|
||||
assert len(get_sequence_groups(out)) == 1
|
||||
|
||||
# Add new requests.
|
||||
for i in range(4):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=65)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
running.append(seq_group)
|
||||
|
||||
# Make sure only 2 requests are scheduled.
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert out.num_batched_tokens == max_num_batched_tokens
|
||||
assert len(get_sequence_groups(out)) == 2
|
||||
assert not running[0].is_prefill()
|
||||
assert running[1].is_prefill()
|
||||
append_new_token(running[0], 1)
|
||||
|
||||
# Although we have enough token budget, we can only schedule max_seqs.
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert seq_group_meta[0].token_chunk_size == 2
|
||||
assert seq_group_meta[1].token_chunk_size == 1
|
||||
assert out.num_batched_tokens == 3
|
||||
assert len(get_sequence_groups(out)) == max_seqs
|
||||
assert not running[0].is_prefill()
|
||||
assert not running[1].is_prefill()
|
tests/core/test_scheduler.py
@ -1,11 +1,16 @@
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest # noqa
|
||||
|
||||
from vllm.config import CacheConfig, SchedulerConfig
|
||||
from vllm.core.scheduler import Scheduler
|
||||
from vllm.sequence import Logprob, SequenceGroup
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.core.interfaces import AllocStatus
|
||||
from vllm.core.policy import PolicyFactory
|
||||
from vllm.core.scheduler import Scheduler, SchedulingBudget
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import Logprob, SequenceGroup, SequenceStatus
|
||||
|
||||
from .utils import create_dummy_prompt
|
||||
|
||||
@ -14,6 +19,26 @@ def get_sequence_groups(scheduler_output):
|
||||
return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
|
||||
|
||||
|
||||
def append_new_token(out, token_id: int):
|
||||
seq_groups = get_sequence_groups(out)
|
||||
for seq_group in seq_groups:
|
||||
for seq in seq_group.get_seqs():
|
||||
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
|
||||
|
||||
|
||||
def schedule_and_update_computed_tokens(scheduler):
|
||||
metas, out = scheduler.schedule()
|
||||
for s, meta in zip(out.scheduled_seq_groups, metas):
|
||||
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
|
||||
return metas, out
|
||||
|
||||
|
||||
def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
|
||||
seq_group.update_num_computed_tokens(token_chunk_size)
|
||||
for seq in seq_group.get_seqs():
|
||||
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
|
||||
|
||||
|
||||
def test_scheduler_add_seq_group():
|
||||
block_size = 4
|
||||
scheduler_config = SchedulerConfig(100, 64, 1)
|
||||
@ -71,20 +96,52 @@ def test_scheduler_schedule_simple():
|
||||
|
||||
# Schedule seq groups prompts.
|
||||
num_tokens = block_size * num_seq_group
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert set(get_sequence_groups(out)) == set(running)
|
||||
assert out.num_batched_tokens == num_tokens
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
and not out.blocks_to_swap_out)
|
||||
assert len(seq_group_meta) == num_seq_group
|
||||
append_new_token(out, 1)
|
||||
|
||||
# Schedule seq groups generation.
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert set(get_sequence_groups(out)) == set(running)
|
||||
assert out.num_batched_tokens == num_seq_group
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
and not out.blocks_to_swap_out)
|
||||
assert len(seq_group_meta) == num_seq_group
|
||||
append_new_token(out, 1)
|
||||
|
||||
|
||||
def test_scheduler_prefill_prioritized():
|
||||
"""Verify running batched tokens are not applied to prefill requests."""
|
||||
block_size = 4
|
||||
max_model_len = 30
|
||||
max_batched_num_tokens = 30
|
||||
scheduler_config = SchedulerConfig(max_batched_num_tokens, 2,
|
||||
max_model_len)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 2
|
||||
cache_config.num_gpu_blocks = 2
|
||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||
|
||||
# Add seq groups to scheduler.
|
||||
_, seq_group_a = create_dummy_prompt("1", 1)
|
||||
scheduler.add_seq_group(seq_group_a)
|
||||
|
||||
# Schedule seq groups prompts.
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert get_sequence_groups(out) == [seq_group_a]
|
||||
|
||||
# Add a new prefill request B.
|
||||
_, seq_group_b = create_dummy_prompt("2", 30)
|
||||
scheduler.add_seq_group(seq_group_b)
|
||||
|
||||
# Verify prefill requests are prioritized. The new prefill request needs 30
# tokens, which consumes the entire token budget, so only request B is
# scheduled.
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert get_sequence_groups(out) == [seq_group_b]
|
||||
|
||||
|
||||
def test_scheduler_schedule_preempt_abort():
|
||||
@ -103,7 +160,7 @@ def test_scheduler_schedule_preempt_abort():
|
||||
scheduler.add_seq_group(seq_group_b)
|
||||
|
||||
# Schedule seq groups prompts.
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert get_sequence_groups(out) == [seq_group_a, seq_group_b]
|
||||
assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
@ -113,12 +170,10 @@ def test_scheduler_schedule_preempt_abort():
|
||||
|
||||
# Append "generated" tokens, allowing the sequence to mark prompt tokens as
|
||||
# processed.
|
||||
token_id = 0
|
||||
seq_a.append_token_id(token_id, {token_id: Logprob(0.0)})
|
||||
seq_b.append_token_id(token_id, {token_id: Logprob(0.0)})
|
||||
append_new_token(out, 1)
|
||||
|
||||
# Schedule seq groups generation and preempt seq group b.
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert get_sequence_groups(out) == [seq_group_a]
|
||||
assert out.num_batched_tokens == 1
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
@ -128,7 +183,7 @@ def test_scheduler_schedule_preempt_abort():
|
||||
|
||||
# Abort seq group a. Re-schedule seq group b prompt with recomputation.
|
||||
scheduler.abort_seq_group("1")
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert get_sequence_groups(out) == [seq_group_b]
|
||||
assert out.num_batched_tokens == 5 # 4 prompt + 1 generation.
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
@ -158,12 +213,14 @@ def test_scheduler_max_seqs():
|
||||
scheduler.add_seq_group(all_seq_groups[0])
|
||||
|
||||
# Schedule seq groups prompts.
|
||||
_, out = scheduler.schedule()
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])
|
||||
append_new_token(out, 1)
|
||||
|
||||
# Schedule seq groups generation.
|
||||
_, out = scheduler.schedule()
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])
|
||||
append_new_token(out, 1)
|
||||
|
||||
# Append 2 more seq group
|
||||
scheduler.add_seq_group(all_seq_groups[1])
|
||||
@ -172,12 +229,11 @@ def test_scheduler_max_seqs():
|
||||
# Schedule seq groups prompts.
|
||||
# Only 1 seq group should be scheduled since max_seq_group is 2
|
||||
# and one is prompting.
|
||||
_, out = scheduler.schedule()
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert set(get_sequence_groups(out)) == set([all_seq_groups[1]])
|
||||
|
||||
|
||||
def test_scheduler_delay_factor():
|
||||
|
||||
block_size = 4
|
||||
scheduler_config = SchedulerConfig(100, 64, 16, delay_factor=0.5)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
@ -186,24 +242,630 @@ def test_scheduler_delay_factor():
|
||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||
|
||||
# schedule first prompt
|
||||
_, seq_group = create_dummy_prompt("0", prompt_length=block_size)
|
||||
seq_group_meta, seq_group = create_dummy_prompt("0",
|
||||
prompt_length=block_size)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
assert out.prompt_run
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert out.num_prefill_groups > 0
|
||||
assert seq_group_meta[0].request_id == '0'
|
||||
append_new_token(out, 1)
|
||||
|
||||
# wait for a second before scheduling next prompt
|
||||
time.sleep(1)
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=block_size)
|
||||
seq_group_meta, seq_group = create_dummy_prompt("1",
|
||||
prompt_length=block_size)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
|
||||
# second prompt should *not* be scheduled
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
assert not out.prompt_run
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert out.num_prefill_groups == 0
|
||||
assert seq_group_meta[0].request_id == '0'
|
||||
append_new_token(out, 1)
|
||||
|
||||
# wait for more than 0.5 second and try again
|
||||
time.sleep(0.6)
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
assert out.prompt_run
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert out.num_prefill_groups > 0
|
||||
assert seq_group_meta[0].request_id == '1'
|
||||
append_new_token(out, 1)
|
||||
|
||||
|
||||
def test_swapped_out_prioritized():
|
||||
scheduler = initialize_scheduler(max_num_seqs=6)
|
||||
# best_of=2 * 3 == 6 sequences.
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
# prefill scheduled now.
|
||||
assert len(out.scheduled_seq_groups) == 3
|
||||
append_new_token(out, 1)
|
||||
|
||||
# The last request should be swapped out.
|
||||
scheduler.block_manager.can_append_slots = MagicMock()
|
||||
|
||||
def cannot_append_second_group(seq_group, num_lookahead_slots):
|
||||
return seq_group.request_id != "2"
|
||||
|
||||
scheduler.block_manager.can_append_slots.side_effect = (
|
||||
cannot_append_second_group)
|
||||
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) == 2
|
||||
assert out.num_batched_tokens == 2
|
||||
assert out.blocks_to_swap_out != {}
|
||||
assert out.blocks_to_swap_in == {}
|
||||
append_new_token(out, 1)
|
||||
|
||||
# Add 1 more task. Swap should be prioritized over prefill.
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
append_new_token(out, 1)
|
||||
assert len(out.scheduled_seq_groups) == 3
|
||||
# 3 decodes. It is swapped in.
|
||||
assert out.num_batched_tokens == 3
|
||||
assert out.blocks_to_swap_in != {}
|
||||
assert out.blocks_to_swap_out == {}
|
||||
|
||||
|
||||
def initialize_scheduler(*,
|
||||
max_num_seqs=1000,
|
||||
max_token_budget=1000,
|
||||
max_model_len=1000,
|
||||
lora_config=None):
|
||||
block_size = 4
|
||||
scheduler_config = SchedulerConfig(max_token_budget, max_num_seqs,
|
||||
max_model_len)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
scheduler = Scheduler(scheduler_config, cache_config, lora_config)
|
||||
return scheduler
|
||||
|
||||
|
||||
def create_token_budget(token_budget: int = 10000,
|
||||
max_num_seqs: int = 10000) -> SchedulingBudget:
|
||||
return SchedulingBudget(
|
||||
token_budget=token_budget,
|
||||
max_num_seqs=max_num_seqs,
|
||||
)
|
||||
|
||||
|
||||
def add_token_budget(budget: SchedulingBudget,
|
||||
num_batched_tokens: int = 0,
|
||||
num_curr_seqs: int = 0):
|
||||
mock_seq_group = create_dummy_prompt('10', prompt_length=60)[1]
|
||||
budget.add_num_batched_tokens(mock_seq_group.request_id,
|
||||
num_batched_tokens)
|
||||
budget.add_num_seqs(mock_seq_group.request_id, num_curr_seqs)
|
||||
|
||||
|
||||
def test_prefill_schedule_max_prompt_len():
|
||||
"""
|
||||
Test prompt longer than max_prompt_len is aborted.
|
||||
"""
|
||||
scheduler = initialize_scheduler(max_model_len=30)
|
||||
_, seq_group = create_dummy_prompt(0, prompt_length=60)
|
||||
waiting = deque([seq_group])
|
||||
budget = create_token_budget()
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
assert len(output.ignored_seq_groups) == 1
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 0
|
||||
assert budget.num_curr_seqs == 0
|
||||
assert len(remaining_waiting) == 0
|
||||
|
||||
|
||||
def test_prefill_schedule_token_budget():
|
||||
"""
|
||||
Test token budget respected.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
waiting = deque()
|
||||
budget = create_token_budget(token_budget=0)
|
||||
for i in range(2):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
|
||||
# 0 token budget == nothing is scheduled.
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 0
|
||||
assert budget.num_curr_seqs == 0
|
||||
assert len(remaining_waiting) == 2
|
||||
|
||||
# 60 token budget == 1 request scheduled.
|
||||
budget = create_token_budget(token_budget=60)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 1
|
||||
assert budget.num_batched_tokens == 60
|
||||
assert budget.num_curr_seqs == 1
|
||||
assert len(remaining_waiting) == 1
|
||||
|
||||
# Test when current_batched_tokens respected.
|
||||
scheduler = initialize_scheduler()
|
||||
waiting = deque()
|
||||
budget = create_token_budget(token_budget=60)
|
||||
add_token_budget(budget, 30, 0)
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
# Cannot schedule a prompt that doesn't fit the budget.
|
||||
waiting.append(seq_group)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 30
|
||||
assert budget.num_curr_seqs == 0
|
||||
assert len(remaining_waiting) == 1
|
||||
budget = create_token_budget(token_budget=90)
|
||||
add_token_budget(budget, 30, 0)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
assert len(output.seq_groups) == 1
|
||||
assert budget.num_batched_tokens == 90
|
||||
assert budget.num_curr_seqs == 1
|
||||
assert len(remaining_waiting) == 0
|
||||
|
||||
|
||||
def test_prefill_schedule_max_seqs():
|
||||
"""
|
||||
Test max seq respected.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
waiting = deque()
|
||||
budget = create_token_budget(max_num_seqs=2)
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 2
|
||||
assert budget.num_batched_tokens == 120
|
||||
assert budget.num_curr_seqs == 2
|
||||
assert len(remaining_waiting) == 1
|
||||
|
||||
# Verify curr_num_seqs respected.
|
||||
waiting = deque()
|
||||
budget = create_token_budget(max_num_seqs=2)
|
||||
add_token_budget(budget, 0, 2)
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 0
|
||||
assert budget.num_curr_seqs == 2
|
||||
assert len(remaining_waiting) == 1
|
||||
|
||||
|
||||
def test_prefill_schedule_max_lora():
|
||||
"""
|
||||
Test max lora is respected and prioritized.
|
||||
"""
|
||||
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
|
||||
scheduler = initialize_scheduler(lora_config=lora_config)
|
||||
waiting = deque()
|
||||
budget = create_token_budget(token_budget=120)
|
||||
curr_loras = set()
|
||||
for i in range(2):
|
||||
_, seq_group = create_dummy_prompt(str(i),
|
||||
prompt_length=60,
|
||||
lora_request=LoRARequest(
|
||||
lora_name=str(i),
|
||||
lora_int_id=i + 1,
|
||||
lora_local_path="abc"))
|
||||
waiting.append(seq_group)
|
||||
# Add two more requests to verify lora is prioritized.
|
||||
# 0: Lora, 1: Lora, 2: regular, 3: regular
|
||||
# In the first iteration, index 0, 2 is scheduled.
|
||||
# If a request is not scheduled because it hits max lora, it is
|
||||
# prioritized. Verify that.
|
||||
for i in range(2, 4):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
# Schedule 2 requests (0 and 2)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, curr_loras)
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 2
|
||||
assert budget.num_batched_tokens == 120
|
||||
assert budget.num_curr_seqs == 2
|
||||
assert len(remaining_waiting) == 2
|
||||
assert len(curr_loras) == 1
|
||||
# The second lora request is scheduled next as FCFS policy.
|
||||
# Reset curr_loras so that it can be scheduled.
|
||||
curr_loras = set()
|
||||
budget = create_token_budget(token_budget=60)
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
remaining_waiting, budget, curr_loras)
|
||||
assert len(output.seq_groups) == 1
|
||||
assert output.seq_groups[0].seq_group.request_id == "1"
|
||||
assert len(remaining_waiting) == 1
|
||||
assert len(curr_loras) == 1
|
||||
assert budget.num_batched_tokens == 60
|
||||
|
||||
|
||||
def test_prefill_schedule_no_block_manager_capacity():
|
||||
"""
|
||||
Test that a sequence cannot be scheduled when the block manager has no capacity.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
waiting = deque()
|
||||
budget = create_token_budget()
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
scheduler.block_manager.can_allocate = MagicMock()
|
||||
scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
assert len(output.ignored_seq_groups) == 0
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 0
|
||||
assert budget.num_curr_seqs == 0
|
||||
assert len(remaining_waiting) == 3
|
||||
|
||||
scheduler = initialize_scheduler()
|
||||
waiting = deque()
|
||||
budget = create_token_budget()
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
waiting.append(seq_group)
|
||||
scheduler.block_manager.can_allocate = MagicMock()
|
||||
scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER
|
||||
remaining_waiting, output = scheduler._schedule_prefills(
|
||||
waiting, budget, None)
|
||||
assert len(output.ignored_seq_groups) == 3
|
||||
assert len(output.seq_groups) == 0
|
||||
assert budget.num_batched_tokens == 0
|
||||
assert budget.num_curr_seqs == 0
|
||||
assert len(remaining_waiting) == 0
|
||||
|
||||
|
||||
def test_decode_schedule_preempted():
|
||||
"""
|
||||
Test that decodes which cannot be scheduled are preempted.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
running = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
scheduler._allocate_and_set_running(seq_group, 60)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
running.append(seq_group)
|
||||
scheduler.block_manager.can_append_slots = MagicMock()
|
||||
|
||||
def cannot_append_second_group(seq_group, num_lookahead_slots):
|
||||
return seq_group.request_id != "1"
|
||||
|
||||
scheduler.block_manager.can_append_slots.side_effect = (
|
||||
cannot_append_second_group)
|
||||
|
||||
# 1 cannot be scheduled, and the lowest priority (request 2)
|
||||
# should be preempted. 1 will also be preempted.
|
||||
budget = create_token_budget()
|
||||
remaining_running, output = scheduler._schedule_running(
|
||||
running, budget, curr_loras, policy)
|
||||
assert len(remaining_running) == 0
|
||||
assert len(output.decode_seq_groups) == 1
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
assert output.decode_seq_groups[0].seq_group.request_id == "0"
|
||||
assert len(output.preempted) == 2
|
||||
# Verify budgets are updated.
|
||||
assert budget.num_batched_tokens == 1
|
||||
assert budget.num_curr_seqs == 1
|
||||
# Both should be preempted, not swapped.
|
||||
assert output.blocks_to_swap_out == {}
|
||||
# Nothing is copied.
|
||||
assert output.blocks_to_copy == {}
|
||||
|
||||
|
||||
def test_decode_swap_beam_search():
|
||||
"""
|
||||
Test best_of > 1 swap out blocks
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
running = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
budget = create_token_budget()
|
||||
for i in range(3):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group, 60)
|
||||
running.append(seq_group)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
budget.add_num_seqs(seq_group.request_id,
|
||||
seq_group.get_max_num_running_seqs())
|
||||
budget.add_num_batched_tokens(
|
||||
seq_group.request_id, seq_group.num_seqs(SequenceStatus.RUNNING))
|
||||
|
||||
# The last request should be swapped out.
|
||||
scheduler.block_manager.can_append_slots = MagicMock()
|
||||
|
||||
def cannot_append_second_group(seq_group, num_lookahead_slots):
|
||||
return seq_group.request_id != "2"
|
||||
|
||||
scheduler.block_manager.can_append_slots.side_effect = (
|
||||
cannot_append_second_group)
|
||||
scheduler.block_manager.swap_out = MagicMock()
|
||||
expected_swap_mapping = {"5": "7"}
|
||||
scheduler.block_manager.swap_out.return_value = expected_swap_mapping
|
||||
|
||||
remaining_running, output = scheduler._schedule_running(
|
||||
running, budget, curr_loras, policy)
|
||||
assert len(remaining_running) == 0
|
||||
assert len(output.decode_seq_groups) == 2
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
assert output.decode_seq_groups[0].seq_group.request_id == "0"
|
||||
assert output.decode_seq_groups[1].seq_group.request_id == "1"
|
||||
assert len(output.preempted) == 0
|
||||
assert len(output.swapped_out) == 1
|
||||
# Budget should reflect the swapped-out request.
|
||||
assert budget.num_batched_tokens == 2
|
||||
# since there are 2 sequences, 2 should be subtracted.
|
||||
assert budget.num_curr_seqs == 4
|
||||
# The last request is swapped out, so the mocked swap mapping is returned.
|
||||
assert output.blocks_to_swap_out == expected_swap_mapping
|
||||
# Nothing is copied.
|
||||
assert output.blocks_to_copy == {}
|
||||
|
||||
|
||||
def test_schedule_decode_blocks_to_copy_update():
|
||||
"""
|
||||
Verify blocks_to_copy is updated.
|
||||
"""
|
||||
scheduler = initialize_scheduler()
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
running = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
scheduler._allocate_and_set_running(seq_group, 60)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
running.append(seq_group)
|
||||
|
||||
# Mock append_slots to return a copy (source -> dests) mapping.
|
||||
scheduler.block_manager.append_slots = MagicMock()
|
||||
scheduler.block_manager.append_slots.return_value = {2: [3]}
|
||||
|
||||
budget = create_token_budget()
|
||||
remaining_running, output = scheduler._schedule_running(
|
||||
running, budget, curr_loras, policy)
|
||||
assert len(remaining_running) == 0
|
||||
assert len(output.decode_seq_groups) == 1
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
assert len(output.preempted) == 0
|
||||
assert len(output.swapped_out) == 0
|
||||
# Nothing is preempted.
|
||||
assert output.blocks_to_swap_out == {}
|
||||
# Since append_slot returns the source -> dest mapping, it should be
# applied.
|
||||
assert output.blocks_to_copy == {2: [3]}
|
||||
|
||||
|
||||
def test_schedule_swapped_simple():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = {}
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group, 60)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
|
||||
budget = create_token_budget()
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
assert len(remaining_swapped) == 0
|
||||
assert budget.num_batched_tokens == 1
|
||||
assert budget.num_curr_seqs == 2
|
||||
assert len(output.decode_seq_groups) == 1
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
# swap in is the reverse of swap out
|
||||
blocks_to_swap_in_reverse = {}
|
||||
for swapin, swapout in output.blocks_to_swap_in.items():
|
||||
blocks_to_swap_in_reverse[swapout] = swapin
|
||||
assert blocks_to_swap_out == blocks_to_swap_in_reverse
|
||||
|
||||
|
||||
def test_schedule_swapped_max_token_budget():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = {}
|
||||
for _ in range(2):
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group, 60)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
|
||||
budget = create_token_budget(token_budget=1)
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
assert len(remaining_swapped) == 1
|
||||
assert budget.num_batched_tokens == 1
|
||||
assert budget.num_curr_seqs == 2
|
||||
assert len(output.decode_seq_groups) == 1
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
|
||||
# Verify num_batched_tokens are respected.
|
||||
budget = create_token_budget(token_budget=1)
|
||||
add_token_budget(budget, 1, 0)
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
remaining_swapped, budget, curr_loras, policy)
|
||||
assert len(remaining_swapped) == 1
|
||||
assert budget.num_batched_tokens == 1
|
||||
assert budget.num_curr_seqs == 0
|
||||
assert len(output.decode_seq_groups) == 0
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
|
||||
|
||||
def test_schedule_swapped_max_seqs():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = {}
|
||||
for i in range(4):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
|
||||
scheduler._allocate_and_set_running(seq_group, 60)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
|
||||
budget = create_token_budget(max_num_seqs=2)
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
assert len(remaining_swapped) == 2
|
||||
assert budget.num_batched_tokens == 2
|
||||
assert budget.num_curr_seqs == 2
|
||||
assert len(output.decode_seq_groups) == 2
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
|
||||
# Verify num_curr_seqs are respected.
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
remaining_swapped, budget, curr_loras, policy)
|
||||
assert len(remaining_swapped) == 2
|
||||
assert budget.num_batched_tokens == 2
|
||||
assert budget.num_curr_seqs == 2
|
||||
assert len(output.decode_seq_groups) == 0
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
|
||||
|
||||
def test_schedule_swapped_max_loras():
|
||||
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
|
||||
scheduler = initialize_scheduler(lora_config=lora_config)
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = set()
|
||||
blocks_to_swap_out = {}
|
||||
for i in range(2):
|
||||
_, seq_group = create_dummy_prompt(str(i),
|
||||
prompt_length=60,
|
||||
lora_request=LoRARequest(
|
||||
lora_name=str(i),
|
||||
lora_int_id=i + 1,
|
||||
lora_local_path="abc"))
|
||||
scheduler._allocate_and_set_running(seq_group, 60)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
|
||||
budget = create_token_budget()
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
assert len(remaining_swapped) == 1
|
||||
assert budget.num_batched_tokens == 1
|
||||
assert budget.num_curr_seqs == 1
|
||||
assert len(output.decode_seq_groups) == 1
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
assert len(curr_loras) == 1
|
||||
|
||||
|
||||
def test_schedule_swapped_cannot_swap_in():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
blocks_to_swap_out = {}
|
||||
for _ in range(2):
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group, 60)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
|
||||
# Mock can_swap_in so that no request can be swapped back in.
|
||||
scheduler.block_manager.can_swap_in = MagicMock()
|
||||
scheduler.block_manager.can_swap_in.return_value = False
|
||||
# Since we cannot swap in, none of the requests are swapped in.
|
||||
budget = create_token_budget()
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
assert len(remaining_swapped) == 2
|
||||
assert budget.num_batched_tokens == 0
|
||||
assert budget.num_curr_seqs == 0
|
||||
assert len(output.decode_seq_groups) == 0
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
|
||||
|
||||
def test_schedule_swapped_blocks_to_copy():
|
||||
scheduler = initialize_scheduler()
|
||||
swapped = deque()
|
||||
policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
curr_loras = None
|
||||
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
|
||||
scheduler._allocate_and_set_running(seq_group, 60)
|
||||
append_new_token_seq_group(60, seq_group, 1)
|
||||
blocks_to_swap_out = {}
|
||||
scheduler._swap_out(seq_group, blocks_to_swap_out)
|
||||
swapped.append(seq_group)
|
||||
|
||||
# Mock append_slots to return a copy (source -> dests) mapping.
|
||||
scheduler.block_manager.append_slots = MagicMock()
|
||||
scheduler.block_manager.append_slots.return_value = {2: [3]}
|
||||
|
||||
budget = create_token_budget()
|
||||
remaining_swapped, output = scheduler._schedule_swapped(
|
||||
swapped, budget, curr_loras, policy)
|
||||
assert len(remaining_swapped) == 0
|
||||
assert len(output.decode_seq_groups) == 1
|
||||
assert len(output.prefill_seq_groups) == 0
|
||||
assert output.blocks_to_copy == {2: [3]}
|
||||
|
||||
|
||||
def test_scheduling_budget():
|
||||
TOKEN_BUDGET = 4
|
||||
MAX_SEQS = 4
|
||||
budget = SchedulingBudget(token_budget=TOKEN_BUDGET, max_num_seqs=MAX_SEQS)
|
||||
assert budget.can_schedule(num_new_tokens=1, num_new_seqs=1)
|
||||
assert budget.can_schedule(num_new_tokens=4, num_new_seqs=4)
|
||||
assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=5)
|
||||
assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=1)
|
||||
assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=5)
|
||||
assert budget.remaining_token_budget() == TOKEN_BUDGET
|
||||
|
||||
# Verify add/subtract num batched tokens.
|
||||
_, seq_group = create_dummy_prompt("1", 3)
|
||||
budget.add_num_batched_tokens(seq_group.request_id, 2)
|
||||
assert budget.remaining_token_budget() == 2
|
||||
assert budget.num_batched_tokens == 2
|
||||
assert budget.can_schedule(num_new_tokens=2, num_new_seqs=1)
|
||||
assert not budget.can_schedule(num_new_tokens=3, num_new_seqs=1)
|
||||
# Verify adding another seq group is no-op.
|
||||
budget.add_num_batched_tokens(seq_group.request_id, 2)
|
||||
assert budget.remaining_token_budget() == 2
|
||||
assert budget.num_batched_tokens == 2
|
||||
budget.subtract_num_batched_tokens(seq_group.request_id, 2)
|
||||
assert budget.remaining_token_budget() == 4
|
||||
assert budget.num_batched_tokens == 0
|
||||
budget.subtract_num_batched_tokens(seq_group.request_id, 2)
|
||||
assert budget.remaining_token_budget() == 4
|
||||
assert budget.num_batched_tokens == 0
|
||||
|
||||
# Verify add/subtract max seqs.
|
||||
_, seq_group = create_dummy_prompt("1", 3)
|
||||
budget.add_num_seqs(seq_group.request_id, 2)
|
||||
assert budget.can_schedule(num_new_tokens=1, num_new_seqs=2)
|
||||
assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=3)
|
||||
assert budget.num_curr_seqs == 2
|
||||
# Verify adding another seq group is no-op.
|
||||
budget.add_num_seqs(seq_group.request_id, 2)
|
||||
assert budget.num_curr_seqs == 2
|
||||
budget.subtract_num_seqs(seq_group.request_id, 2)
|
||||
assert budget.num_curr_seqs == 0
|
||||
budget.subtract_num_seqs(seq_group.request_id, 2)
|
||||
assert budget.num_curr_seqs == 0
|
||||
|
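Beyond the unit assertions above, the intended bookkeeping pattern for SchedulingBudget is: check can_schedule() before admitting a group, then record its tokens and sequences under the group's request_id. The sketch below illustrates that pattern using only the SchedulingBudget methods exercised in this test; the admit_groups helper and its argument layout are illustrative, not part of the diff.

```python
# Illustrative sketch; uses only the SchedulingBudget API shown above.
def admit_groups(budget: SchedulingBudget, candidates):
    """candidates: iterable of (seq_group, num_new_tokens, num_new_seqs)."""
    admitted = []
    for seq_group, num_new_tokens, num_new_seqs in candidates:
        if not budget.can_schedule(num_new_tokens=num_new_tokens,
                                   num_new_seqs=num_new_seqs):
            break  # FCFS: stop at the first group that does not fit.
        budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
        budget.add_num_seqs(seq_group.request_id, num_new_seqs)
        admitted.append(seq_group)
    return admitted
```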
@@ -1,14 +1,19 @@
import time
from typing import Tuple
from typing import Optional, Tuple

from vllm import SamplingParams
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, Sequence, SequenceGroup


def create_dummy_prompt(
        request_id: str,
        prompt_length: int,
        block_size: int = None) -> Tuple[Sequence, SequenceGroup]:
        block_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        use_beam_search: bool = False,
        best_of: int = 1,
) -> Tuple[Sequence, SequenceGroup]:
    if not block_size:
        block_size = prompt_length

@@ -17,8 +22,10 @@ def create_dummy_prompt(
    prompt_tokens = list(range(prompt_length))
    prompt_str = " ".join([str(t) for t in prompt_tokens])
    prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
    seq_group = SequenceGroup(request_id, [prompt], SamplingParams(),
                              time.time(), None)
    seq_group = SequenceGroup(
        request_id, [prompt],
        SamplingParams(use_beam_search=use_beam_search, best_of=best_of),
        time.time(), lora_request)

    return prompt, seq_group
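With the extended signature, callers can set up LoRA and best_of scenarios in a single call. A brief usage sketch follows, relying only on create_dummy_prompt and LoRARequest as they appear in this diff (the lora_local_path value is a placeholder, as in the tests):

```python
# Illustrative usage of the extended helper, not part of the diff.
_, plain_group = create_dummy_prompt("0", prompt_length=16)
_, beam_group = create_dummy_prompt("1", prompt_length=16, best_of=2)
_, lora_group = create_dummy_prompt(
    "2",
    prompt_length=16,
    lora_request=LoRARequest(lora_name="dummy",
                             lora_int_id=1,
                             lora_local_path="abc"))
```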
@ -33,11 +33,16 @@ def test_models(
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2)
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=2,
|
||||
)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
|
tests/distributed/test_chunked_prefill_distributed.py (Normal file, 66 lines)
@@ -0,0 +1,66 @@
|
||||
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
|
||||
vLLM will allocate all the available memory, so we need to run the tests one
|
||||
by one. The solution is to pass arguments (model name) by environment
|
||||
variables.
|
||||
|
||||
Run:
|
||||
```sh
|
||||
TEST_DIST_MODEL=facebook/opt-125m pytest \
|
||||
test_chunked_prefill_distributed.py
|
||||
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \
|
||||
test_chunked_prefill_distributed.py
|
||||
```
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
MODELS = [
|
||||
os.environ["TEST_DIST_MODEL"],
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
chunked_prefill_token_size: int,
|
||||
) -> None:
|
||||
# Add a chunked prefill config.
|
||||
max_num_seqs = min(chunked_prefill_token_size, 256)
|
||||
assert chunked_prefill_token_size != -1
|
||||
enable_chunked_prefill = True
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=2,
|
||||
max_num_seqs=max_num_seqs,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
@ -8,8 +8,8 @@ import pytest
|
||||
import ray
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
broadcast_tensor_dict, tensor_model_parallel_all_gather,
|
||||
from vllm.distributed import (broadcast_tensor_dict,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.test_utils import (init_test_distributed_environment,
|
||||
multi_process_tensor_parallel)
|
||||
|
@ -6,9 +6,8 @@ import ray
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.device_communicators import custom_all_reduce
|
||||
from vllm.test_utils import (init_test_distributed_environment,
|
||||
multi_process_tensor_parallel)
|
||||
|
||||
@ -26,10 +25,10 @@ def graph_allreduce(world_size, rank, distributed_init_port):
|
||||
init_test_distributed_environment(1, world_size, rank,
|
||||
distributed_init_port)
|
||||
|
||||
custom_ar.init_custom_ar()
|
||||
custom_all_reduce.init_custom_all_reduce()
|
||||
for sz in test_sizes:
|
||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
with custom_ar.capture():
|
||||
with custom_all_reduce.capture():
|
||||
# use integers so result matches NCCL exactly
|
||||
inp1 = torch.randint(1,
|
||||
16, (sz, ),
|
||||
@ -62,8 +61,8 @@ def eager_allreduce(world_size, rank, distributed_init_port):
|
||||
distributed_init_port)
|
||||
|
||||
sz = 1024
|
||||
custom_ar.init_custom_ar()
|
||||
fa = custom_ar.get_handle()
|
||||
custom_all_reduce.init_custom_all_reduce()
|
||||
fa = custom_all_reduce.get_handle()
|
||||
inp = torch.ones(sz, dtype=torch.float32, device=device)
|
||||
out = fa.all_reduce_unreg(inp)
|
||||
assert torch.allclose(out, inp * world_size)
|
||||
|
@ -4,7 +4,7 @@ import os
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
|
||||
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
|
||||
ncclGetUniqueId)
|
||||
|
||||
|
||||
|
tests/engine/test_detokenization.py (Normal file, 32 lines)
@@ -0,0 +1,32 @@
|
||||
import pytest

from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
def test_detokenize_toggle(model: str):
    # This test checks that the engine generates completions both with and
    # without optional detokenization, that the detokenized output includes
    # text while the non-detokenized one doesn't, and that both completions
    # have the same token_ids.
    prompt = (
        "You are a helpful assistant. How do I build a car from cardboard and "
        "paper clips? Is there an easy to follow video tutorial available "
        "online for free?")

    llm = LLM(model=model)
    sampling_params = SamplingParams(max_tokens=10,
                                     temperature=0.0,
                                     detokenize=False)

    outputs_no_detokenization = llm.generate(prompt,
                                             sampling_params)[0].outputs[0]
    sampling_params.detokenize = True
    outputs_with_detokenization = llm.generate(prompt,
                                               sampling_params)[0].outputs[0]

    assert outputs_no_detokenization.text == ''
    assert outputs_with_detokenization.text != ''
    assert outputs_no_detokenization.token_ids == \
        outputs_with_detokenization.token_ids
|
@ -3,7 +3,7 @@
|
||||
2. One of the provided stop tokens
|
||||
3. The EOS token
|
||||
|
||||
Run `pytest tests/samplers/test_stop_reason.py`.
|
||||
Run `pytest tests/engine/test_stop_reason.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
tests/engine/test_stop_strings.py (Normal file, 111 lines)
@@ -0,0 +1,111 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import CompletionOutput, LLMEngine, SamplingParams
|
||||
|
||||
MODEL = "meta-llama/llama-2-7b-hf"
|
||||
MAX_TOKENS = 200
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def vllm_model(vllm_runner):
|
||||
return vllm_runner(MODEL)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_basic(vllm_model):
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["."],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer organization",
|
||||
expected_reason=".")
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["."],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organization.",
|
||||
expected_reason=".")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_multi_tokens(vllm_model):
|
||||
_test_stopping(
|
||||
vllm_model.model.llm_engine,
|
||||
stop=["group of peo", "short"],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer organization. We are a ",
|
||||
expected_reason="group of peo")
|
||||
|
||||
_test_stopping(
|
||||
vllm_model.model.llm_engine,
|
||||
stop=["group of peo", "short"],
|
||||
include_in_output=True,
|
||||
expected_output=
|
||||
"VLLM is a 100% volunteer organization. We are a group of peo",
|
||||
expected_reason="group of peo")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_partial_token(vllm_model):
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["gani"],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer or",
|
||||
expected_reason="gani")
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["gani"],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organi",
|
||||
expected_reason="gani")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_token_id(vllm_model):
|
||||
# token id 13013 => " organization"
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop_token_ids=[13013],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer",
|
||||
expected_reason=13013)
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop_token_ids=[13013],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organization",
|
||||
expected_reason=13013)
|
||||
|
||||
|
||||
def _test_stopping(llm_engine: LLMEngine,
|
||||
expected_output: str,
|
||||
expected_reason: Any,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
include_in_output: bool = False) -> None:
|
||||
llm_engine.add_request(
|
||||
"id", "A story about vLLM:\n",
|
||||
SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=MAX_TOKENS,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
include_stop_str_in_output=include_in_output,
|
||||
), None)
|
||||
|
||||
output: Optional[CompletionOutput] = None
|
||||
output_text = ""
|
||||
stop_reason = None
|
||||
while llm_engine.has_unfinished_requests():
|
||||
(request_output, ) = llm_engine.step()
|
||||
(output, ) = request_output.outputs
|
||||
|
||||
# Ensure we don't backtrack
|
||||
assert output.text.startswith(output_text)
|
||||
output_text = output.text
|
||||
stop_reason = output.stop_reason
|
||||
|
||||
assert output is not None
|
||||
assert output_text == expected_output
|
||||
assert stop_reason == expected_reason
|
@ -1,11 +1,14 @@
|
||||
# This unit test should be moved to a new
|
||||
# tests/test_guided_decoding directory.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
|
||||
RegexLogitsProcessor)
|
||||
from vllm.entrypoints.openai.protocol import CompletionRequest
|
||||
from vllm.model_executor.guided_decoding import (
|
||||
get_guided_decoding_logits_processor)
|
||||
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
||||
JSONLogitsProcessor, RegexLogitsProcessor)
|
||||
|
||||
TEST_SCHEMA = {
|
||||
"type": "object",
|
||||
@ -73,3 +76,36 @@ def test_guided_logits_processors():
|
||||
json_LP(token_ids, tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_logits_processor_black_box(backend: str):
|
||||
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
|
||||
token_ids = tokenizer.encode(
|
||||
f"Give an example IPv4 address with this regex: {TEST_REGEX}")
|
||||
regex_request = CompletionRequest(model='test',
|
||||
prompt=token_ids,
|
||||
guided_regex=TEST_REGEX)
|
||||
regex_lp = await get_guided_decoding_logits_processor(
|
||||
backend, regex_request, tokenizer)
|
||||
assert regex_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = regex_lp(token_ids, tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
||||
token_ids = tokenizer.encode(
|
||||
f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
|
||||
json_request = CompletionRequest(model='test',
|
||||
prompt=token_ids,
|
||||
guided_json=TEST_SCHEMA)
|
||||
json_lp = await get_guided_decoding_logits_processor(
|
||||
backend, json_request, tokenizer)
|
||||
assert json_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = json_lp(token_ids, tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
@ -141,7 +141,7 @@ def server(zephyr_lora_files):
|
||||
"--max-cpu-loras",
|
||||
"2",
|
||||
"--max-num-seqs",
|
||||
"128"
|
||||
"128",
|
||||
])
|
||||
ray.get(server_runner.ready.remote())
|
||||
yield server_runner
|
||||
@ -506,7 +506,10 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
|
||||
assert first_response != completion.choices[0].text
|
||||
|
||||
|
||||
async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str):
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=f"Give an example JSON for an employee profile "
|
||||
@ -514,7 +517,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
|
||||
n=3,
|
||||
temperature=1.0,
|
||||
max_tokens=500,
|
||||
extra_body=dict(guided_json=TEST_SCHEMA))
|
||||
extra_body=dict(guided_json=TEST_SCHEMA,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 3
|
||||
@ -524,7 +528,10 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
|
||||
jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
|
||||
|
||||
|
||||
async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
@ -538,8 +545,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=500,
|
||||
extra_body=dict(guided_json=TEST_SCHEMA))
|
||||
max_tokens=1000,
|
||||
extra_body=dict(guided_json=TEST_SCHEMA,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
message = chat_completion.choices[0].message
|
||||
assert message.content is not None
|
||||
json1 = json.loads(message.content)
|
||||
@ -555,8 +563,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=500,
|
||||
extra_body=dict(guided_json=TEST_SCHEMA))
|
||||
max_tokens=1000,
|
||||
extra_body=dict(guided_json=TEST_SCHEMA,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
message = chat_completion.choices[0].message
|
||||
assert message.content is not None
|
||||
json2 = json.loads(message.content)
|
||||
@ -565,14 +574,18 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
|
||||
assert json1["age"] != json2["age"]
|
||||
|
||||
|
||||
async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str):
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
|
||||
n=3,
|
||||
temperature=1.0,
|
||||
max_tokens=20,
|
||||
extra_body=dict(guided_regex=TEST_REGEX))
|
||||
extra_body=dict(guided_regex=TEST_REGEX,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 3
|
||||
@ -581,7 +594,10 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
|
||||
assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
|
||||
|
||||
|
||||
async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
@ -595,7 +611,8 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=20,
|
||||
extra_body=dict(guided_regex=TEST_REGEX))
|
||||
extra_body=dict(guided_regex=TEST_REGEX,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
ip1 = chat_completion.choices[0].message.content
|
||||
assert ip1 is not None
|
||||
assert re.fullmatch(TEST_REGEX, ip1) is not None
|
||||
@ -606,21 +623,26 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=20,
|
||||
extra_body=dict(guided_regex=TEST_REGEX))
|
||||
extra_body=dict(guided_regex=TEST_REGEX,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
ip2 = chat_completion.choices[0].message.content
|
||||
assert ip2 is not None
|
||||
assert re.fullmatch(TEST_REGEX, ip2) is not None
|
||||
assert ip1 != ip2
|
||||
|
||||
|
||||
async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str):
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt="The best language for type-safe systems programming is ",
|
||||
n=2,
|
||||
temperature=1.0,
|
||||
max_tokens=10,
|
||||
extra_body=dict(guided_choice=TEST_CHOICE))
|
||||
extra_body=dict(guided_choice=TEST_CHOICE,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 2
|
||||
@ -628,7 +650,10 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
|
||||
assert completion.choices[i].text in TEST_CHOICE
|
||||
|
||||
|
||||
async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
@ -642,7 +667,8 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
extra_body=dict(guided_choice=TEST_CHOICE))
|
||||
extra_body=dict(guided_choice=TEST_CHOICE,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
choice1 = chat_completion.choices[0].message.content
|
||||
assert choice1 in TEST_CHOICE
|
||||
|
||||
@ -655,18 +681,23 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
extra_body=dict(guided_choice=TEST_CHOICE))
|
||||
extra_body=dict(guided_choice=TEST_CHOICE,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
choice2 = chat_completion.choices[0].message.content
|
||||
assert choice2 in TEST_CHOICE
|
||||
assert choice1 != choice2
|
||||
|
||||
|
||||
async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str):
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
_ = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt="Give an example JSON that fits this schema: 42",
|
||||
extra_body=dict(guided_json=42))
|
||||
extra_body=dict(guided_json=42,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
|
||||
messages = [{
|
||||
"role": "system",
|
||||
@ -742,5 +773,36 @@ number: "1" | "2"
|
||||
assert content.strip() == ground_truth
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
# first test base model, then test loras
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
|
||||
)
|
||||
async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
|
||||
# test using text and token IDs
|
||||
for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
|
||||
completion = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
echo=True,
|
||||
logprobs=1)
|
||||
|
||||
prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
|
||||
list) else prompt
|
||||
assert (completion.choices[0].text is not None
|
||||
and re.search(r"^" + prompt_text, completion.choices[0].text))
|
||||
logprobs = completion.choices[0].logprobs
|
||||
assert logprobs is not None
|
||||
assert len(logprobs.text_offset) > 5
|
||||
assert (len(logprobs.token_logprobs) > 5
|
||||
and logprobs.token_logprobs[0] is None)
|
||||
assert (len(logprobs.top_logprobs) > 5
|
||||
and logprobs.top_logprobs[0] is None)
|
||||
assert len(logprobs.tokens) > 5
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
tests/entrypoints/test_server_oot_registration.py (Normal file, 66 lines)
@@ -0,0 +1,66 @@
|
||||
import multiprocessing
|
||||
import sys
|
||||
import time
|
||||
|
||||
import torch
|
||||
from openai import OpenAI, OpenAIError
|
||||
|
||||
from vllm import ModelRegistry
|
||||
from vllm.model_executor.models.opt import OPTForCausalLM
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.utils import get_open_port
|
||||
|
||||
|
||||
class MyOPTForCausalLM(OPTForCausalLM):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
# this dummy model always predicts the first token
|
||||
logits = super().compute_logits(hidden_states, sampling_metadata)
|
||||
logits.zero_()
|
||||
logits[:, 0] += 1.0
|
||||
return logits
|
||||
|
||||
|
||||
def server_function(port):
|
||||
# register our dummy model
|
||||
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
|
||||
sys.argv = ["placeholder.py"] + \
|
||||
("--model facebook/opt-125m --dtype"
|
||||
f" float32 --api-key token-abc123 --port {port}").split()
|
||||
import runpy
|
||||
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
|
||||
|
||||
|
||||
def test_oot_registration_for_api_server():
|
||||
port = get_open_port()
|
||||
server = multiprocessing.Process(target=server_function, args=(port, ))
|
||||
server.start()
|
||||
client = OpenAI(
|
||||
base_url=f"http://localhost:{port}/v1",
|
||||
api_key="token-abc123",
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
completion = client.chat.completions.create(
|
||||
model="facebook/opt-125m",
|
||||
messages=[{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Hello!"
|
||||
}],
|
||||
temperature=0,
|
||||
)
|
||||
break
|
||||
except OpenAIError as e:
|
||||
if "Connection error" in str(e):
|
||||
time.sleep(3)
|
||||
else:
|
||||
raise e
|
||||
server.kill()
|
||||
generated_text = completion.choices[0].message.content
|
||||
# make sure only the first token is generated
|
||||
rest = generated_text.replace("<s>", "")
|
||||
assert rest == ""
|
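The same registration hook can also be exercised without starting the API server. A minimal offline sketch, assuming the MyOPTForCausalLM class defined above plus the LLM and SamplingParams entrypoints used elsewhere in this diff (the wrapper function name is illustrative):

```python
# Illustrative offline variant, not part of the diff.
from vllm import LLM, ModelRegistry, SamplingParams


def generate_with_oot_model():
    # Register the out-of-tree implementation before building the engine.
    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
    llm = LLM(model="facebook/opt-125m", dtype="float32")
    outputs = llm.generate("Hello!",
                           SamplingParams(temperature=0, max_tokens=8))
    # The dummy model always predicts token id 0, so the decoded text should
    # contain only BOS-like tokens.
    return outputs[0].outputs[0].text
```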
tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json (Normal file, 90 lines)
@@ -0,0 +1,90 @@
|
||||
{
|
||||
"model_type": "llama",
|
||||
"kv_cache": {
|
||||
"dtype": "float8_e4m3fn",
|
||||
"scaling_factor": {
|
||||
"0": {
|
||||
"0": 0.0230364128947258,
|
||||
"1": 0.01979283057153225,
|
||||
"2": 0.0241350457072258,
|
||||
"3": 0.0308314748108387,
|
||||
"4": 0.0430733822286129,
|
||||
"5": 0.0370396226644516,
|
||||
"6": 0.0306222103536129,
|
||||
"7": 0.0357491634786129,
|
||||
"8": 0.0358189195394516,
|
||||
"9": 0.0443289652466774,
|
||||
"10": 0.0433175228536129,
|
||||
"11": 0.0416782945394516,
|
||||
"12": 0.0366908498108387,
|
||||
"13": 0.0432477705180645,
|
||||
"14": 0.0410505048930645,
|
||||
"15": 0.0457589291036129,
|
||||
"16": 0.0418526791036129,
|
||||
"17": 0.0432477705180645,
|
||||
"18": 0.0469447560608387,
|
||||
"19": 0.0514787957072258,
|
||||
"20": 0.0541294664144516,
|
||||
"21": 0.0587681382894516,
|
||||
"22": 0.0625,
|
||||
"23": 0.0585588738322258,
|
||||
"24": 0.0600237175822258,
|
||||
"25": 0.0588030144572258,
|
||||
"26": 0.0531180277466774,
|
||||
"27": 0.06396484375,
|
||||
"28": 0.0603027381002903,
|
||||
"29": 0.0582101047039032,
|
||||
"30": 0.0625348836183548,
|
||||
"31": 0.0585588738322258,
|
||||
"32": 0.0582798570394516,
|
||||
"33": 0.0575125589966774,
|
||||
"34": 0.0590820349752903,
|
||||
"35": 0.0614188089966774,
|
||||
"36": 0.0631975457072258,
|
||||
"37": 0.0615931935608387,
|
||||
"38": 0.0601283498108387,
|
||||
"39": 0.0571986623108387,
|
||||
"40": 0.0670340433716774,
|
||||
"41": 0.0523507259786129,
|
||||
"42": 0.0547223798930645,
|
||||
"43": 0.0631975457072258,
|
||||
"44": 0.0663713738322258,
|
||||
"45": 0.0603376142680645,
|
||||
"46": 0.0652204304933548,
|
||||
"47": 0.0734514519572258,
|
||||
"48": 0.0693708211183548,
|
||||
"49": 0.0725446492433548,
|
||||
"50": 0.0627790242433548,
|
||||
"51": 0.0691266804933548,
|
||||
"52": 0.0688825398683548,
|
||||
"53": 0.068429134786129,
|
||||
"54": 0.0605119988322258,
|
||||
"55": 0.0799386203289032,
|
||||
"56": 0.0853097140789032,
|
||||
"57": 0.0661969929933548,
|
||||
"58": 0.0689871683716774,
|
||||
"59": 0.0724051371216774,
|
||||
"60": 0.0541643425822258,
|
||||
"61": 0.0626743882894516,
|
||||
"62": 0.0628487765789032,
|
||||
"63": 0.0607212632894516,
|
||||
"64": 0.0589076466858387,
|
||||
"65": 0.0451660193502903,
|
||||
"66": 0.0453055277466774,
|
||||
"67": 0.0414341539144516,
|
||||
"68": 0.0385044664144516,
|
||||
"69": 0.0414341539144516,
|
||||
"70": 0.0466308631002903,
|
||||
"71": 0.0399693101644516,
|
||||
"72": 0.0437011756002903,
|
||||
"73": 0.0434221550822258,
|
||||
"74": 0.0428989976644516,
|
||||
"75": 0.0401785746216774,
|
||||
"76": 0.0431082621216774,
|
||||
"77": 0.0484444759786129,
|
||||
"78": 0.0417829267680645,
|
||||
"79": 0.0418178029358387
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json (new file, 42 lines)
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"model_type": "llama",
|
||||
"kv_cache": {
|
||||
"dtype": "float8_e4m3fn",
|
||||
"scaling_factor": {
|
||||
"0": {
|
||||
"0": 0.0152239128947258,
|
||||
"1": 0.0188860222697258,
|
||||
"2": 0.0354178324341774,
|
||||
"3": 0.0376674123108387,
|
||||
"4": 0.0418526791036129,
|
||||
"5": 0.0433175228536129,
|
||||
"6": 0.0397600457072258,
|
||||
"7": 0.0424455925822258,
|
||||
"8": 0.0415387861430645,
|
||||
"9": 0.0408412404358387,
|
||||
"10": 0.0395856611430645,
|
||||
"11": 0.0377371683716774,
|
||||
"12": 0.0400739423930645,
|
||||
"13": 0.040771484375,
|
||||
"14": 0.0393415205180645,
|
||||
"15": 0.0369001142680645,
|
||||
"16": 0.03857421875,
|
||||
"17": 0.0387486070394516,
|
||||
"18": 0.0403180830180645,
|
||||
"19": 0.0396205373108387,
|
||||
"20": 0.0375627800822258,
|
||||
"21": 0.0407366082072258,
|
||||
"22": 0.0432477705180645,
|
||||
"23": 0.0377022884786129,
|
||||
"24": 0.0399693101644516,
|
||||
"25": 0.0374581478536129,
|
||||
"26": 0.0413295216858387,
|
||||
"27": 0.0442243330180645,
|
||||
"28": 0.0424804724752903,
|
||||
"29": 0.0456891767680645,
|
||||
"30": 0.0409109964966774,
|
||||
"31": 0.0482352152466774
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
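
Both JSON files share the same layout: a model type, the FP8 dtype of the KV cache, and a scaling_factor table nested first by tensor-parallel rank ("0" above) and then by layer index. A small reader sketch under that assumption; this is illustrative only, not vLLM's internal loader:

import json
from typing import Dict


def load_kv_cache_scales(path: str, tp_rank: int = 0) -> Dict[int, float]:
    """Return {layer_index: scaling_factor} for one TP rank.

    Assumes the schema shown in the files above.
    """
    with open(path) as f:
        data = json.load(f)
    assert data["kv_cache"]["dtype"] == "float8_e4m3fn"
    per_rank = data["kv_cache"]["scaling_factor"][str(tp_rank)]
    return {int(layer): float(scale) for layer, scale in per_rank.items()}


# Example (path taken from this diff):
# scales = load_kv_cache_scales("tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json")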
@@ -7,7 +7,7 @@ from allclose_default import get_default_atol, get_default_rtol
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from vllm._C import cache_ops, ops
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
@@ -32,7 +32,7 @@ HEAD_SIZES = [64, 80, 96, 112, 128, 256

BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
@@ -172,6 +172,9 @@ def test_paged_attention(
device)
key_cache, value_cache = key_caches[0], value_caches[0]

# Using default kv_scale
kv_scale = 1.0

# Call the paged attention kernel.
output = torch.empty_like(query)
if version == "v1":
@@ -188,6 +191,7 @@ def test_paged_attention(
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
elif version == "v2":
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
@@ -219,12 +223,13 @@ def test_paged_attention(
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
else:
raise AssertionError(f"Unknown version: {version}")

# Run the reference implementation.
if kv_cache_dtype == "fp8_e5m2":
if kv_cache_dtype == "fp8":
# Convert cache data back to dtype.
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
@@ -232,14 +237,14 @@ def test_paged_attention(
dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device=device)
cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache)
ops.convert_fp8(key_cache, dequantized_key_cache)
key_cache = dequantized_key_cache

value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=device)
cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache)
ops.convert_fp8(value_cache, dequantized_value_cache)
value_cache = dequantized_value_cache

ref_output = torch.empty_like(query)
@@ -263,7 +268,8 @@ def test_paged_attention(

# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.
if kv_cache_dtype == "fp8_e5m2":
atol, rtol = 1e-3, 1e-5
if kv_cache_dtype == "fp8":
atol, rtol = 1e-2, 1e-5
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
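
The relaxed atol/rtol for the fp8 path exists because an 8-bit float round trip loses precision. A quick standalone check of that effect, assuming PyTorch >= 2.1 for the float8 dtypes; it uses plain tensor casts rather than vLLM's convert_fp8 kernel, so it only illustrates the size of the error:

import torch

x = torch.empty(1024, dtype=torch.float32).uniform_(-224.0, 224.0)
x_fp8 = x.to(torch.float8_e4m3fn)   # quantize to an 8-bit float
x_back = x_fp8.to(torch.float32)    # dequantize back
abs_err = (x - x_back).abs()
# e4m3 keeps only a 3-bit mantissa, so the error grows with magnitude,
# which is why the kernel test above widens its tolerance for fp8 caches.
print(f"max abs error: {abs_err.max():.2f}, mean abs error: {abs_err.mean():.2f}")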
@ -4,7 +4,8 @@ from typing import Tuple
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm._C import cache_ops
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import is_hip
|
||||
|
||||
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
@ -23,7 +24,7 @@ SEEDS = [0]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
|
||||
KV_CACHE_DTYPE = ["auto", "fp8"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
||||
@ -79,7 +80,7 @@ def test_copy_blocks(
|
||||
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
|
||||
|
||||
# Call the copy blocks kernel.
|
||||
cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
|
||||
ops.copy_blocks(key_caches, value_caches, block_mapping)
|
||||
|
||||
# Run the reference implementation.
|
||||
for src, dsts in block_mapping.items():
|
||||
@ -105,6 +106,7 @@ def test_copy_blocks(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@torch.inference_mode()
|
||||
def test_reshape_and_cache(
|
||||
kv_cache_factory,
|
||||
@ -116,7 +118,10 @@ def test_reshape_and_cache(
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
kv_cache_dtype: str,
|
||||
) -> None:
|
||||
if not is_hip() and kv_cache_dtype == "fp8":
|
||||
pytest.skip() # This test is not tuned for e5m2 cuda precision
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
@ -132,17 +137,33 @@ def test_reshape_and_cache(
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
|
||||
num_heads, head_size, dtype,
|
||||
None, seed, device)
|
||||
num_heads, head_size,
|
||||
kv_cache_dtype, dtype, seed,
|
||||
device)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Clone the KV caches.
|
||||
if kv_cache_dtype == "fp8":
|
||||
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(key_cache, cloned_key_cache)
|
||||
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(value_cache, cloned_value_cache)
|
||||
else:
|
||||
cloned_key_cache = key_cache.clone()
|
||||
cloned_value_cache = value_cache.clone()
|
||||
|
||||
# Using default kv_scale
|
||||
kv_scale = 1.0
|
||||
|
||||
# Call the reshape_and_cache kernel.
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
|
||||
slot_mapping, "auto")
|
||||
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
|
||||
kv_cache_dtype, kv_scale)
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(key_cache, result_key_cache)
|
||||
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(value_cache, result_value_cache)
|
||||
|
||||
# Run the reference implementation.
|
||||
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
|
||||
@ -156,6 +177,16 @@ def test_reshape_and_cache(
|
||||
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
|
||||
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
assert torch.allclose(result_key_cache,
|
||||
cloned_key_cache,
|
||||
atol=0.001,
|
||||
rtol=0.1)
|
||||
assert torch.allclose(result_value_cache,
|
||||
cloned_value_cache,
|
||||
atol=0.001,
|
||||
rtol=0.1)
|
||||
else:
|
||||
assert torch.allclose(key_cache, cloned_key_cache)
|
||||
assert torch.allclose(value_cache, cloned_value_cache)
|
||||
|
||||
@ -169,6 +200,7 @@ def test_reshape_and_cache(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@torch.inference_mode()
|
||||
def test_swap_blocks(
|
||||
kv_cache_factory,
|
||||
@ -181,7 +213,12 @@ def test_swap_blocks(
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
kv_cache_dtype: str,
|
||||
) -> None:
|
||||
if kv_cache_dtype == "fp8" and "cpu" in direction:
|
||||
pytest.skip()
|
||||
if not is_hip() and kv_cache_dtype == "fp8":
|
||||
pytest.skip() # This test is not tuned for e5m2 cuda precision
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
@ -202,24 +239,60 @@ def test_swap_blocks(
|
||||
|
||||
# Create the KV caches on the first device.
|
||||
src_key_caches, src_value_caches = kv_cache_factory(
|
||||
num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed,
|
||||
src_device)
|
||||
num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
|
||||
seed, src_device)
|
||||
|
||||
# Create the KV caches on the second device.
|
||||
dist_key_caches, dist_value_caches = kv_cache_factory(
|
||||
num_blocks, block_size, 1, num_heads, head_size, dtype, None, seed,
|
||||
dst_device)
|
||||
num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
|
||||
seed, dst_device)
|
||||
|
||||
src_key_caches_clone = src_key_caches[0].clone()
|
||||
src_value_caches_clone = src_value_caches[0].clone()
|
||||
|
||||
# Call the swap_blocks kernel.
|
||||
cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
|
||||
cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
|
||||
block_mapping)
|
||||
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
|
||||
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)
|
||||
|
||||
for src, dst in block_mapping.items():
|
||||
assert torch.allclose(src_key_caches_clone[src].cpu(),
|
||||
dist_key_caches[0][dst].cpu())
|
||||
assert torch.allclose(src_value_caches_clone[src].cpu(),
|
||||
dist_value_caches[0][dst].cpu())
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_hip(), reason="FP8 conversion test requires e4m3")
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_fp8_conversion(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
low = -224.0
|
||||
high = 224.0
|
||||
shape = (num_blocks, num_heads, head_size, block_size)
|
||||
cache = torch.empty(shape, dtype=dtype, device=device)
|
||||
cache.uniform_(low, high)
|
||||
|
||||
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
|
||||
ops.convert_fp8(cache, cache_fp8)
|
||||
|
||||
converted_cache = torch.empty_like(cache)
|
||||
ops.convert_fp8(cache_fp8, converted_cache)
|
||||
|
||||
assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
|
||||
|
@@ -73,7 +73,7 @@ def test_mixtral_moe(dtype: torch.dtype):
).cuda()

# Load the weights
vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data)
@@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download

import vllm
from vllm.config import LoRAConfig
from vllm.distributed import destroy_model_parallel, initialize_model_parallel
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
@@ -19,8 +20,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils.parallel_state import (
destroy_model_parallel, initialize_model_parallel)


def cleanup():
@@ -144,6 +143,11 @@ def baichuan_lora_files():
return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")


@pytest.fixture(scope="session")
def tinyllama_lora_files():
    return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")


@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup()
@@ -62,7 +62,7 @@ def test_baichuan_lora(baichuan_lora_files):


@pytest.mark.skip("Requires multiple GPUs")
def test_llama_tensor_parallel_equality(baichuan_lora_files):
def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 4:
#     pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
@ -170,7 +170,8 @@ def create_random_inputs(
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_embeddings(dist_init, num_loras, device) -> None:
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
|
||||
def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
@ -179,9 +180,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
|
||||
lora_dtype=torch.float16)
|
||||
|
||||
def create_random_embedding_layer():
|
||||
embedding = VocabParallelEmbedding(512, 256)
|
||||
embedding = VocabParallelEmbedding(vocab_size, 256)
|
||||
embedding.weight.data = torch.rand_like(embedding.weight.data)
|
||||
embedding.weight.data[512:, :] = 0
|
||||
embedding.weight.data[vocab_size:, :] = 0
|
||||
lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
|
||||
lora_embedding.create_lora_weights(max_loras, lora_config)
|
||||
|
||||
@ -203,12 +204,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
|
||||
active_lora_ids=list(lora_dict.keys()),
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200, ),
|
||||
input_range=(1, 512),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
512, lora_config.lora_extra_vocab_size)
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_embedding.set_mapping(*mapping_info)
|
||||
|
||||
lora_result = lora_embedding(torch.cat(inputs))
|
||||
@ -240,12 +242,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
|
||||
active_lora_ids=[0],
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200, ),
|
||||
input_range=(1, 512),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
512, lora_config.lora_extra_vocab_size)
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_embedding.set_mapping(*mapping_info, )
|
||||
|
||||
lora_result = lora_embedding(torch.cat(inputs))
|
||||
@ -263,7 +266,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
|
||||
# reason="Fails when loras are in any slot other than the first.")
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
|
||||
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
||||
vocab_size) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
@ -272,15 +277,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
lora_dtype=torch.float16)
|
||||
|
||||
def create_random_embedding_layer():
|
||||
embedding = VocabParallelEmbedding(512, 256)
|
||||
embedding = VocabParallelEmbedding(vocab_size, 256)
|
||||
embedding_data = torch.rand_like(embedding.weight.data)
|
||||
embedding.weight.data = embedding_data
|
||||
embedding.weight.data[512:, :] = 0
|
||||
embedding.weight.data[vocab_size:, :] = 0
|
||||
expanded_embedding = VocabParallelEmbedding(
|
||||
512 + lora_config.lora_extra_vocab_size * max_loras,
|
||||
vocab_size + lora_config.lora_extra_vocab_size * max_loras,
|
||||
256,
|
||||
org_num_embeddings=512)
|
||||
expanded_embedding.weight.data[:512, :] = embedding_data
|
||||
org_num_embeddings=vocab_size)
|
||||
expanded_embedding.weight.data[:vocab_size, :] = embedding_data
|
||||
# We need to deepcopy the embedding as it will be modified
|
||||
# in place
|
||||
lora_embedding = VocabParallelEmbeddingWithLoRA(
|
||||
@ -298,7 +303,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
id_to_index,
|
||||
layer=lora_embedding,
|
||||
layer_weights=torch.zeros(
|
||||
(256, 512 + lora_config.lora_extra_vocab_size)),
|
||||
(256, vocab_size + lora_config.lora_extra_vocab_size)),
|
||||
generate_embeddings_tensor=256,
|
||||
)
|
||||
|
||||
@ -316,7 +321,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
active_lora_ids=list(lora_dict.keys()),
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200, ),
|
||||
input_range=(1, 512),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
@ -327,16 +332,18 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
for input_, original_input_, lora_id in zip(inputs, original_inputs,
|
||||
prompt_mapping):
|
||||
embedding_id = lora_id - 1
|
||||
input_[-1] = 512 + (embedding_id * embeddings_tensor_len)
|
||||
original_input_[-1] = 512
|
||||
input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1)
|
||||
original_input_[-2] = 512 + embeddings_tensor_len - 1
|
||||
input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
|
||||
original_input_[-1] = vocab_size
|
||||
input_[-2] = vocab_size + (
|
||||
(embedding_id + 1) * embeddings_tensor_len - 1)
|
||||
original_input_[-2] = vocab_size + embeddings_tensor_len - 1
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
512, lora_config.lora_extra_vocab_size)
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_embedding.set_mapping(*mapping_info, )
|
||||
|
||||
expanded_embedding.weight[512:512 +
|
||||
expanded_embedding.weight[vocab_size:vocab_size +
|
||||
(embeddings_tensor_len *
|
||||
max_loras)] = torch.cat(embeddings_tensors)
|
||||
|
||||
@ -370,14 +377,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
active_lora_ids=[0],
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200, ),
|
||||
input_range=(1, 512),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
original_inputs = deepcopy(inputs)
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
512, lora_config.lora_extra_vocab_size)
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_embedding.set_mapping(*mapping_info, )
|
||||
|
||||
lora_result = lora_embedding(torch.cat(original_inputs))
|
||||
@ -393,7 +401,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
|
||||
def test_lm_head_logits_processor(dist_init, num_loras, device,
|
||||
vocab_size) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
@ -402,12 +412,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
lora_dtype=torch.float16)
|
||||
|
||||
def _pretest():
|
||||
linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
|
||||
1024, 32000)
|
||||
linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
|
||||
1024, vocab_size)
|
||||
linear.weight.data = torch.rand_like(linear.weight.data)
|
||||
linear.weight.data[:, 32000:] = 0
|
||||
linear.weight.data[:, vocab_size:] = 0
|
||||
logits_processor = LogitsProcessor(
|
||||
32000 + lora_config.lora_extra_vocab_size, 32000)
|
||||
vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
|
||||
lora_logits_processor = LogitsProcessorWithLoRA(
|
||||
logits_processor, 1024, linear.weight.dtype, linear.weight.device)
|
||||
lora_logits_processor.create_lora_weights(max_loras, lora_config)
|
||||
@ -444,7 +454,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
lora_mapping,
|
||||
id_to_index,
|
||||
max_loras,
|
||||
32000,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
lora_logits_processor.set_mapping(*mapping_info, )
|
||||
@ -460,7 +470,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
org_vocab_size:logits_processor.org_vocab_size +
|
||||
embeddings_tensor_len] = embeddings_tensor
|
||||
|
||||
logits_processor.org_vocab_size = (32000 +
|
||||
logits_processor.org_vocab_size = (vocab_size +
|
||||
lora_config.lora_extra_vocab_size)
|
||||
expected_results = []
|
||||
for input_, lora_id in zip(inputs, prompt_mapping):
|
||||
@ -468,11 +478,11 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
result = logits_processor._get_logits(hidden_states=input_,
|
||||
embedding=linear.weight,
|
||||
embedding_bias=None)
|
||||
result[:, 32000 + embeddings_tensor_len:] = float("-inf")
|
||||
result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
|
||||
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
|
||||
expected_results.append(result)
|
||||
expected_result = torch.cat(expected_results)
|
||||
logits_processor.org_vocab_size = 32000
|
||||
logits_processor.org_vocab_size = vocab_size
|
||||
|
||||
# Check that resetting the lora weights succeeds
|
||||
|
||||
@ -489,14 +499,14 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
32000,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_logits_processor.set_mapping(*mapping_info, )
|
||||
|
||||
lora_result = lora_logits_processor._get_logits(
|
||||
hidden_states=torch.cat(inputs),
|
||||
embedding=original_weight,
|
||||
embedding_bias=None)[:, :32000]
|
||||
embedding_bias=None)[:, :vocab_size]
|
||||
expected_result = logits_processor._get_logits(
|
||||
hidden_states=torch.cat(inputs),
|
||||
embedding=original_weight,
|
||||
|
tests/lora/test_lora_checkpoints.py (new file, 40 lines)
@@ -0,0 +1,40 @@
import pytest

from vllm.lora.models import LoRAModel
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM


@pytest.mark.parametrize("lora_name", ["baichuan7B", "chatglm3-6b"])
def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
    supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
    expected_lora_modules = []
    for module in supported_lora_modules:
        if module in packed_modules_mapping:
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:
            expected_lora_modules.append(module)
    if lora_name == "baichuan7B":
        # For the baichuan7B model, load its LoRA,
        # and the test should pass.
        LoRAModel.from_local_checkpoint(
            baichuan_lora_files,
            expected_lora_modules,
            lora_model_id=1,
            device="cpu",
            embedding_modules=embedding_modules,
            embedding_padding_modules=embed_padding_modules)
    else:
        # For the baichuan7B model, load chatglm3-6b's LoRA,
        # and the test should raise the following error.
        expected_error = "Please verify that the loaded LoRA module is correct"  # noqa: E501
        with pytest.raises(ValueError, match=expected_error):
            LoRAModel.from_local_checkpoint(
                chatglm3_lora_files,
                expected_lora_modules,
                lora_model_id=1,
                device="cpu",
                embedding_modules=embedding_modules,
                embedding_padding_modules=embed_padding_modules)
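
The expansion loop above turns packed module names into the per-projection names a LoRA checkpoint actually stores. A toy illustration of what it produces; the mapping values here are invented for the example and are not Baichuan's real packed_modules_mapping:

# Hypothetical mapping, for illustration only.
packed_modules_mapping = {"W_pack": ["q_proj", "k_proj", "v_proj"]}
supported_lora_modules = ["W_pack", "o_proj", "gate_up_proj"]

expected_lora_modules = []
for module in supported_lora_modules:
    if module in packed_modules_mapping:
        # a packed module expands to its constituent projection names
        expected_lora_modules.extend(packed_modules_mapping[module])
    else:
        expected_lora_modules.append(module)

assert expected_lora_modules == [
    "q_proj", "k_proj", "v_proj", "o_proj", "gate_up_proj"
]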
@@ -43,10 +43,52 @@ def _lora_ref_impl(


H1 = H2 = [
128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456,
3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216,
10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512,
32768, 33024
128,
256,
512,
1024,
1152,
1280,
1536,
2048,
2304,
2560,
2752,
3072,
3456,
3584,
4096,
4608,
5120,
5504,
5632,
6144,
6848,
6912,
7168,
8192,
9216,
10240,
11008,
13824,
14336,
15360,
22016,
24576,
27392,
32000,
32256,
32512,
32768,
33024,
36864,
49152,
64000,
64256,
102400,
102656,
128000,
128256,
]
SEED = [0xabcdabcd987]
CUDA_DEVICES = [
tests/lora/test_quant_model.py (new file, 179 lines)
@@ -0,0 +1,179 @@
|
||||
# Adapted from
|
||||
# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
import vllm
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
from .conftest import cleanup
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelWithQuantization:
|
||||
model_path: str
|
||||
quantization: str
|
||||
|
||||
|
||||
MODELS: List[ModelWithQuantization] = [
|
||||
ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
|
||||
quantization="AWQ"),
|
||||
ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
|
||||
quantization="GPTQ"),
|
||||
]
|
||||
|
||||
|
||||
def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256):
|
||||
raw_prompts = [
|
||||
"Give me an orange-ish brown color",
|
||||
"Give me a neon pink color",
|
||||
]
|
||||
|
||||
def format_prompt_tuples(prompt):
|
||||
return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
prompts = [format_prompt_tuples(p) for p in raw_prompts]
|
||||
|
||||
sampling_params = vllm.SamplingParams(temperature=0,
|
||||
max_tokens=max_tokens,
|
||||
stop=["<|im_end|>"])
|
||||
outputs = llm.generate(
|
||||
prompts,
|
||||
sampling_params,
|
||||
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
|
||||
if lora_id else None)
|
||||
# Print the outputs.
|
||||
generated_texts = []
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
generated_texts.append(generated_text)
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
return generated_texts
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("tp_size", [1])
|
||||
def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
|
||||
# Cannot use as it will initialize torch.cuda too early...
|
||||
# if torch.cuda.device_count() < tp_size:
|
||||
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
|
||||
|
||||
llm = vllm.LLM(model=model.model_path,
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_loras=4,
|
||||
max_model_len=400,
|
||||
tensor_parallel_size=tp_size,
|
||||
quantization=model.quantization,
|
||||
trust_remote_code=True)
|
||||
|
||||
if model.quantization is None:
|
||||
expected_no_lora_output = [
|
||||
"Here are some examples of orange-brown colors",
|
||||
"I'm sorry, I don't have"
|
||||
]
|
||||
expected_lora_output = [
|
||||
"#ff8050",
|
||||
"#ff8080",
|
||||
]
|
||||
elif model.quantization == "AWQ":
|
||||
expected_no_lora_output = [
|
||||
"I'm sorry, I don't understand",
|
||||
"I'm sorry, I don't understand",
|
||||
]
|
||||
expected_lora_output = [
|
||||
"#f07700: A v",
|
||||
"#f00000: A v",
|
||||
]
|
||||
elif model.quantization == "GPTQ":
|
||||
expected_no_lora_output = [
|
||||
"I'm sorry, I don't have",
|
||||
"I'm sorry, I don't have",
|
||||
]
|
||||
expected_lora_output = [
|
||||
"#f08800: This is",
|
||||
"#f07788 \n#",
|
||||
]
|
||||
|
||||
def expect_match(output, expected_output):
|
||||
# HACK: GPTQ lora outputs are just incredibly unstable.
|
||||
# Assert that the outputs changed.
|
||||
if (model.quantization == "GPTQ"
|
||||
and expected_output is expected_lora_output):
|
||||
assert output != expected_no_lora_output
|
||||
for i, o in enumerate(output):
|
||||
assert o.startswith(
|
||||
'#'), f"Expected example {i} to start with # but got {o}"
|
||||
return
|
||||
assert output == expected_output
|
||||
|
||||
max_tokens = 10
|
||||
|
||||
print("lora adapter created")
|
||||
output = do_sample(llm,
|
||||
tinyllama_lora_files,
|
||||
lora_id=0,
|
||||
max_tokens=max_tokens)
|
||||
expect_match(output, expected_no_lora_output)
|
||||
|
||||
print("lora 1")
|
||||
output = do_sample(llm,
|
||||
tinyllama_lora_files,
|
||||
lora_id=1,
|
||||
max_tokens=max_tokens)
|
||||
expect_match(output, expected_lora_output)
|
||||
|
||||
print("no lora")
|
||||
output = do_sample(llm,
|
||||
tinyllama_lora_files,
|
||||
lora_id=0,
|
||||
max_tokens=max_tokens)
|
||||
expect_match(output, expected_no_lora_output)
|
||||
|
||||
print("lora 2")
|
||||
output = do_sample(llm,
|
||||
tinyllama_lora_files,
|
||||
lora_id=2,
|
||||
max_tokens=max_tokens)
|
||||
expect_match(output, expected_lora_output)
|
||||
|
||||
print("removing lora")
|
||||
|
||||
del llm
|
||||
cleanup()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.skip("Requires multiple GPUs")
|
||||
def test_quant_model_tp_equality(tinyllama_lora_files, model):
|
||||
# Cannot use as it will initialize torch.cuda too early...
|
||||
# if torch.cuda.device_count() < 2:
|
||||
# pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
|
||||
|
||||
llm_tp1 = vllm.LLM(model=model.model_path,
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_loras=4,
|
||||
tensor_parallel_size=1,
|
||||
quantization=model.quantization,
|
||||
trust_remote_code=True)
|
||||
output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)
|
||||
|
||||
del llm_tp1
|
||||
cleanup()
|
||||
|
||||
llm_tp2 = vllm.LLM(model=model.model_path,
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_loras=4,
|
||||
tensor_parallel_size=2,
|
||||
quantization=model.quantization)
|
||||
output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)
|
||||
|
||||
del llm_tp2
|
||||
cleanup()
|
||||
|
||||
assert output_tp1 == output_tp2
|
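
The core pattern this file exercises is per-request LoRA on top of a quantized base model: build the LLM with enable_lora plus a quantization method, then pass a LoRARequest to generate. A compact sketch using the same APIs shown above; the adapter path is a placeholder and any adapter compatible with the base model would do:

import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM(model="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
               quantization="AWQ",
               enable_lora=True,
               max_loras=4)
outputs = llm.generate(
    ["Give me a neon pink color"],
    vllm.SamplingParams(temperature=0, max_tokens=10),
    # lora_request=None would run the plain quantized base model.
    lora_request=LoRARequest("colorist", 1, "/path/to/lora/adapter"))
print(outputs[0].outputs[0].text)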
@@ -3,8 +3,8 @@ import random
import tempfile
from unittest.mock import patch

from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.worker.worker import Worker
@@ -27,6 +27,10 @@ def test_worker_apply_lora(sql_lora_files):
parallel_config=ParallelConfig(1, 1, False),
scheduler_config=SchedulerConfig(32, 32, 32),
device_config=DeviceConfig("cuda"),
cache_config=CacheConfig(block_size=16,
gpu_memory_utilization=1.,
swap_space=0,
cache_dtype="auto"),
local_rank=0,
rank=0,
lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
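
The hunk above shows that Worker now takes an explicit CacheConfig. The same object can be built on its own with the four fields used here; the comments on units are an assumption based on vLLM's engine arguments at this revision rather than something stated in the diff:

from vllm.config import CacheConfig

cache_config = CacheConfig(
    block_size=16,               # tokens per KV-cache block
    gpu_memory_utilization=0.9,  # fraction of GPU memory reserved for the cache
    swap_space=4,                # CPU swap space, in GiB
    cache_dtype="auto",          # or "fp8" to quantize the KV cache
)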
tests/model_executor/weight_utils.py (new file, 26 lines)
@@ -0,0 +1,26 @@
import os

import huggingface_hub.constants
import pytest

from vllm.model_executor.weight_utils import enable_hf_transfer


def test_hf_transfer_auto_activation():
    if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ:
        # in case it is already set, we can't test the auto activation
        pytest.skip(
            "HF_HUB_ENABLE_HF_TRANSFER is set, can't test auto activation")
    enable_hf_transfer()
    try:
        # enable hf hub transfer if available
        import hf_transfer  # type: ignore # noqa
        HF_TRANFER_ACTIVE = True
    except ImportError:
        HF_TRANFER_ACTIVE = False
    assert (huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER ==
            HF_TRANFER_ACTIVE)


if __name__ == "__main__":
    test_hf_transfer_auto_activation()
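
The test asserts that hf_transfer-backed downloads are enabled exactly when the optional package is importable. A sketch of how such auto-activation can be implemented; this mirrors the behaviour the test checks, and may differ in detail from vLLM's actual enable_hf_transfer:

import huggingface_hub.constants


def enable_hf_transfer_sketch():
    """Turn on hf_transfer downloads only if the package is installed."""
    try:
        import hf_transfer  # noqa: F401
    except ImportError:
        return  # leave HF_HUB_ENABLE_HF_TRANSFER untouched
    huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True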
@@ -12,7 +12,7 @@ MODELS = [
"gpt2",
"bigcode/tiny_starcoder_py",
"EleutherAI/pythia-70m",
"bigscience/bloom-560m",
"bigscience/bloom-560m",  # Testing alibi slopes.
"microsoft/phi-2",
"stabilityai/stablelm-3b-4e1t",
# "allenai/OLMo-1B", # Broken
tests/models/test_oot_registration.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import torch

from vllm import LLM, ModelRegistry, SamplingParams
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata


class MyOPTForCausalLM(OPTForCausalLM):

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        # this dummy model always predicts the first token
        logits = super().compute_logits(hidden_states, sampling_metadata)
        logits.zero_()
        logits[:, 0] += 1.0
        return logits


def test_oot_registration():
    # register our dummy model
    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
    prompts = ["Hello, my name is", "The text does not matter"]
    sampling_params = SamplingParams(temperature=0)
    llm = LLM(model="facebook/opt-125m")
    first_token = llm.get_tokenizer().decode(0)
    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        generated_text = output.outputs[0].text
        # make sure only the first token is generated
        rest = generated_text.replace(first_token, "")
        assert rest == ""
tests/samplers/test_logits_processor.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import pytest
import torch

from vllm import SamplingParams

MODELS = ["facebook/opt-125m"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_logits_processor_force_generate(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    vllm_model = vllm_runner(model, dtype=dtype)
    tokenizer = vllm_model.model.get_tokenizer()
    repeat_times = 2
    enforced_answers = " vLLM"
    vllm_token_ids = tokenizer.encode(enforced_answers,
                                      add_special_tokens=False)
    max_tokens = len(vllm_token_ids) * repeat_times

    def pick_vllm(token_ids, logits):
        token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)]
        logits[token_id] = torch.finfo(logits.dtype).max
        return logits

    params_with_logprobs = SamplingParams(
        logits_processors=[pick_vllm],
        prompt_logprobs=3,
        max_tokens=max_tokens,
    )

    # test logits_processors when prompt_logprobs is not None
    vllm_model.model._add_request(
        prompt=example_prompts[0],
        sampling_params=params_with_logprobs,
        prompt_token_ids=None,
    )

    # test prompt_logprobs is not None
    vllm_model.model._add_request(
        prompt=example_prompts[1],
        sampling_params=SamplingParams(
            prompt_logprobs=3,
            max_tokens=max_tokens,
        ),
        prompt_token_ids=None,
    )

    # test grouped requests
    vllm_model.model._add_request(
        prompt=example_prompts[2],
        sampling_params=SamplingParams(max_tokens=max_tokens),
        prompt_token_ids=None,
    )

    outputs = vllm_model.model._run_engine(False)

    assert outputs[0].outputs[0].text == enforced_answers * repeat_times
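
The pick_vllm callback above shows the logits-processor contract: a callable that receives the previously generated token ids and the logits for the next token, and returns (possibly modified) logits. A simpler example that forbids one token id through the same SamplingParams hook; the token id 2 here is arbitrary:

import torch

from vllm import SamplingParams


def ban_token(token_ids, logits):
    logits[2] = -float("inf")  # the model can never emit token id 2
    return logits


params = SamplingParams(max_tokens=16, logits_processors=[ban_token])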
@ -1,3 +1,4 @@
|
||||
import itertools
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
from unittest.mock import patch
|
||||
@ -194,11 +195,15 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
|
||||
def create_sampling_params(min_tokens,
|
||||
eos_token_id=0,
|
||||
stop_token_ids=None):
|
||||
*,
|
||||
stop_token_ids: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None):
|
||||
sampling_params = SamplingParams(
|
||||
min_tokens=min_tokens,
|
||||
max_tokens=9999, # keep higher than max of min_tokens
|
||||
stop_token_ids=stop_token_ids,
|
||||
# requesting prompt_logprobs changes the structure of `logits`
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
)
|
||||
sampling_params.eos_token_id = eos_token_id
|
||||
return sampling_params
|
||||
@ -217,9 +222,9 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
|
||||
expected_penalization = []
|
||||
sequence_metadata_list = []
|
||||
while batch_size > 0:
|
||||
# 20% chance to generate prompt seq group with single sequence
|
||||
# 20% chance to generate seq group metadata list with all prompts
|
||||
is_prompt = random.random() < 0.2
|
||||
while batch_size > 0:
|
||||
num_seqs = 1 if is_prompt else random.randint(1, batch_size)
|
||||
|
||||
eos_token_id = random.randint(0, VOCAB_SIZE - 1)
|
||||
@ -240,7 +245,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
seq_group_penalization = []
|
||||
for _ in range(num_seqs):
|
||||
num_input = random.randint(1, 100)
|
||||
num_generated = random.randint(1, 100) if not is_prompt else 0
|
||||
num_generated = 0 if is_prompt else random.randint(1, 100)
|
||||
seq_data[next(seq_id_counter)] = create_sequence_data(
|
||||
num_input=num_input, num_generated=num_generated)
|
||||
seq_group_penalization.append(num_generated < min_tokens)
|
||||
@ -292,6 +297,21 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
]
|
||||
}
|
||||
|
||||
prompt_with_penalization_and_prompt_logprobs = {
|
||||
"expected_penalization": [False, False, True],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_1",
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(num_input=3),
|
||||
},
|
||||
sampling_params=create_sampling_params(1, prompt_logprobs=3),
|
||||
block_tables={},
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
stop_penalizing_after_min_tokens = {
|
||||
"expected_penalization": [False],
|
||||
"seq_group_metadata_list": [
|
||||
@ -309,8 +329,34 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
}
|
||||
|
||||
stop_token_ids = [42, 99, 42, 0] # intentional duplication
|
||||
simple_combination = {
|
||||
"expected_penalization": [True, False, False],
|
||||
prompt_combination = {
|
||||
"expected_penalization": [False, True, False],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_2",
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(num_input=2),
|
||||
},
|
||||
sampling_params=create_sampling_params(1, prompt_logprobs=3),
|
||||
block_tables={},
|
||||
),
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_3",
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(),
|
||||
},
|
||||
sampling_params=create_sampling_params(
|
||||
0, stop_token_ids=stop_token_ids),
|
||||
block_tables={},
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
stop_token_ids = [1, 999, 37, 37] # intentional duplication
|
||||
decode_combination = {
|
||||
"expected_penalization": [True, False, False, True, False],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_1",
|
||||
@ -327,14 +373,19 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
),
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_2",
|
||||
is_prompt=True,
|
||||
is_prompt=False,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(),
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=20),
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=1),
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=10),
|
||||
},
|
||||
sampling_params=create_sampling_params(
|
||||
0, stop_token_ids=stop_token_ids),
|
||||
10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
|
||||
block_tables={},
|
||||
)
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
@ -342,8 +393,10 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
test_cases = [
|
||||
prompt_without_penalization,
|
||||
prompt_with_penalization,
|
||||
prompt_with_penalization_and_prompt_logprobs,
|
||||
stop_penalizing_after_min_tokens,
|
||||
simple_combination,
|
||||
prompt_combination,
|
||||
decode_combination,
|
||||
]
|
||||
else:
|
||||
test_cases = [generate_test_case()]
|
||||
@ -351,30 +404,49 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
def run_test_case(*,
|
||||
expected_penalization=None,
|
||||
seq_group_metadata_list=None):
|
||||
assert expected_penalization, "Invalid test case"
|
||||
assert seq_group_metadata_list, "Invalid test case"
|
||||
assert expected_penalization, \
|
||||
"Invalid test case, need expected_penalization"
|
||||
assert seq_group_metadata_list, \
|
||||
"Invalid test case, need seq_group_metadata_list"
|
||||
|
||||
batch_size = 0
|
||||
prompt_lens = []
|
||||
sampling_params_per_seq = []
|
||||
sampling_params_per_row = []
|
||||
for sgm in seq_group_metadata_list:
|
||||
num_seqs = len(sgm.seq_data)
|
||||
batch_size += num_seqs
|
||||
sampling_params = sgm.sampling_params
|
||||
for seq_id in sgm.seq_data:
|
||||
prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len())
|
||||
sampling_params_per_seq.append(sampling_params)
|
||||
|
||||
num_rows = len(sgm.seq_data)
|
||||
if sgm.is_prompt:
|
||||
# a prompt seq_group has only one sequence
|
||||
seq_data = next(iter(sgm.seq_data.values()))
|
||||
prompt_len = seq_data.get_prompt_len()
|
||||
prompt_lens.append(prompt_len)
|
||||
|
||||
if sgm.sampling_params.prompt_logprobs:
|
||||
# with prompt_logprobs each token in the prompt has a row in
|
||||
# logits
|
||||
num_rows = prompt_len
|
||||
|
||||
batch_size += num_rows
|
||||
sampling_params_per_row.extend(
|
||||
itertools.repeat(sampling_params, num_rows))
|
||||
|
||||
assert len(
|
||||
expected_penalization
|
||||
) == batch_size, \
|
||||
("Invalid test case, expected_penalization does not match computed"
|
||||
"batch size")
|
||||
|
||||
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
|
||||
sampling_metadata = model_runner._prepare_sample(
|
||||
seq_group_metadata_list,
|
||||
prompt_lens=prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
prompt_lens=prompt_lens if prompt_lens else None,
|
||||
subquery_lens=prompt_lens if prompt_lens else None)
|
||||
# the logits tensor is modified in-place by the sampler
|
||||
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
|
||||
|
||||
for logits_idx, (should_penalize, sampling_params) in enumerate(
|
||||
zip(expected_penalization, sampling_params_per_seq)):
|
||||
zip(expected_penalization, sampling_params_per_row)):
|
||||
|
||||
tokens_to_check = [sampling_params.eos_token_id]
|
||||
if sampling_params.stop_token_ids:
|
||||
|
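
The sampler-test changes above track a bookkeeping rule: when a prompt sequence group requests prompt_logprobs, every prompt token gets its own row in the logits tensor, otherwise a group contributes one row per sequence. A tiny arithmetic sketch of that rule as the updated test encodes it; the helper name num_logits_rows is hypothetical:

def num_logits_rows(is_prompt: bool, prompt_len: int, num_seqs: int,
                    prompt_logprobs: bool) -> int:
    # Mirrors the row counting in run_test_case above.
    if is_prompt and prompt_logprobs:
        return prompt_len
    return 1 if is_prompt else num_seqs


assert num_logits_rows(True, prompt_len=3, num_seqs=1, prompt_logprobs=True) == 3
assert num_logits_rows(True, prompt_len=3, num_seqs=1, prompt_logprobs=False) == 1
assert num_logits_rows(False, prompt_len=3, num_seqs=4, prompt_logprobs=False) == 4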
tests/spec_decode/e2e/conftest.py (new file, 41 lines)
@@ -0,0 +1,41 @@
import pytest

from tests.conftest import cleanup
from vllm import LLM
from vllm.model_executor.utils import set_random_seed


@pytest.fixture
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                           baseline_llm_kwargs, seed):
    return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                                baseline_llm_kwargs, seed)


@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                                test_llm_kwargs, seed)


def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                         distinct_llm_kwargs, seed):
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **distinct_llm_kwargs,
    }

    def generator_inner():
        llm = LLM(**kwargs)

        set_random_seed(seed)

        yield llm
        del llm
        cleanup()

    for llm in generator_inner():
        yield llm
        del llm
tests/spec_decode/e2e/test_correctness.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import pytest

from vllm import SamplingParams


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",
        "speculative_model": "facebook/opt-125m",
        "num_speculative_tokens": 5,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_config(test_llm_generator):
    output_len = 1024
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    with pytest.raises(
            AssertionError,
            match="Speculative decoding not yet supported for GPU backend"):
        get_token_ids_from_llm_generator(test_llm_generator, prompts,
                                         sampling_params)


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    for llm in llm_generator:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        token_ids = [output.outputs[0].token_ids for output in outputs]
        del llm

    return token_ids
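
get_token_ids_from_llm_generator is the building block for comparing a baseline LLM against a speculative-decoding LLM once the feature lands. A sketch of the intended greedy-equality check, combining it with the two fixtures from the conftest above; this is illustrative and would not pass at this revision, where the test only asserts the "not yet supported" error:

def assert_greedy_equality(baseline_llm_generator, test_llm_generator,
                           prompts, sampling_params):
    baseline_ids = get_token_ids_from_llm_generator(baseline_llm_generator,
                                                    prompts, sampling_params)
    test_ids = get_token_ids_from_llm_generator(test_llm_generator, prompts,
                                                sampling_params)
    # With temperature=0.0, both runs should produce identical token ids.
    assert baseline_ids == test_ids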
@@ -7,6 +7,7 @@ from .utils import create_seq_group_metadata_from_prompts, mock_worker


@pytest.mark.parametrize('num_target_seq_ids', [100])
@pytest.mark.skip_global_cleanup
def test_create_target_seq_id_iterator(num_target_seq_ids: int):
"""Verify all new sequence ids are greater than all input
seq ids.
@@ -27,6 +28,7 @@ def test_create_target_seq_id_iterator(num_target_seq_ids: int):


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.skip_global_cleanup
def test_get_token_ids_to_score(k: int):
"""Verify correct tokens are selected for scoring.
"""
@@ -53,6 +55,7 @@ def test_get_token_ids_to_score(k: int):


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.skip_global_cleanup
def test_create_single_target_seq_group_metadata(k: int):
"""Verify correct creation of a batch-expanded seq group metadata.
"""