mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-21 07:13:52 +08:00
Compare commits
77 Commits
v0.8.2
...
sampler-en
Author | SHA1 | Date | |
---|---|---|---|
4c42267293 | |||
24f68342b4 | |||
c5d963835b | |||
b313220727 | |||
15dac210f0 | |||
112b3e5b3b | |||
32d669275b | |||
4098b72210 | |||
46450b8d33 | |||
13ac9cab21 | |||
66aa4c0bf4 | |||
247181536f | |||
07bf813fb5 | |||
8958217ad5 | |||
ac5bc615b0 | |||
8063dfc61a | |||
6278bc829e | |||
3f532cb6a6 | |||
e6c9053f9e | |||
43ed4143c4 | |||
f4c98b4d4c | |||
e1e0fd7543 | |||
df8d3d1287 | |||
619d3de8bd | |||
ecff8309a3 | |||
dcf2a590f5 | |||
54aa619459 | |||
fb22be5817 | |||
7f301dd8ef | |||
8095341a01 | |||
69db16a46a | |||
ce78f9af4e | |||
9239bf718e | |||
7a6d45bc8a | |||
e74ff409e0 | |||
7a888271f5 | |||
9d119a86ae | |||
b2e85e26f4 | |||
dd8a29da99 | |||
27df5199d9 | |||
35fad35a48 | |||
733e7c9e95 | |||
0af4d764d6 | |||
e64afa455c | |||
1711b929b6 | |||
c091c0a588 | |||
1aa162e030 | |||
cf5c8f1686 | |||
4ec2cee000 | |||
99f536f830 | |||
5ebf66748b | |||
781d056280 | |||
5aefd6ac31 | |||
6c663dfd5e | |||
33437bc6e7 | |||
23114d3364 | |||
997c8811d6 | |||
e42389f9d7 | |||
ff38f0a32c | |||
a5cfbab3c8 | |||
ac3cd6e83c | |||
082ab86f5f | |||
6aa196c8dc | |||
a0dd7dcd49 | |||
e977c11111 | |||
5f063a80bd | |||
5d8e1c9279 | |||
0a049c7d86 | |||
d0cfec7ab9 | |||
a608160027 | |||
3f04a7fbf2 | |||
5994430b84 | |||
a9e879b316 | |||
3e2f37a69a | |||
4f044b1d67 | |||
4157f563b4 | |||
051da7efe3 |
@ -134,9 +134,10 @@ if [[ $commands == *"--shard-id="* ]]; then
|
||||
# assign shard-id for each shard
|
||||
commands_gpu=${commands//"--shard-id= "/"--shard-id=${GPU} "}
|
||||
echo "Shard ${GPU} commands:$commands_gpu"
|
||||
echo "Render devices: $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES"
|
||||
docker run \
|
||||
--device /dev/kfd --device /dev/dri \
|
||||
--network host \
|
||||
--device /dev/kfd $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES \
|
||||
--network=host \
|
||||
--shm-size=16gb \
|
||||
--rm \
|
||||
-e HIP_VISIBLE_DEVICES="${GPU}" \
|
||||
@ -163,9 +164,10 @@ if [[ $commands == *"--shard-id="* ]]; then
|
||||
fi
|
||||
done
|
||||
else
|
||||
echo "Render devices: $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES"
|
||||
docker run \
|
||||
--device /dev/kfd --device /dev/dri \
|
||||
--network host \
|
||||
--device /dev/kfd $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES \
|
||||
--network=host \
|
||||
--shm-size=16gb \
|
||||
--rm \
|
||||
-e HIP_VISIBLE_DEVICES=0 \
|
||||
|
@ -38,6 +38,8 @@ function cpu_tests() {
|
||||
set -e
|
||||
pip install -r vllm/requirements/test.txt
|
||||
pip install -r vllm/requirements/cpu.txt
|
||||
pytest -v -s tests/kernels/test_cache.py -m cpu_model
|
||||
pytest -v -s tests/kernels/test_mla_decode_cpu.py -m cpu_model
|
||||
pytest -v -s tests/models/decoder_only/language -m cpu_model
|
||||
pytest -v -s tests/models/embedding/language -m cpu_model
|
||||
pytest -v -s tests/models/encoder_decoder/language -m cpu_model
|
||||
|
@ -22,7 +22,7 @@ docker run --privileged --net host --shm-size=16G -it \
|
||||
&& export VLLM_USE_V1=1 \
|
||||
&& export VLLM_XLA_CHECK_RECOMPILATION=1 \
|
||||
&& echo TEST_1 \
|
||||
&& pytest /workspace/vllm/tests/tpu/test_compilation.py \
|
||||
&& pytest -v -s /workspace/vllm/tests/tpu/test_compilation.py \
|
||||
&& echo TEST_2 \
|
||||
&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \
|
||||
&& echo TEST_3 \
|
||||
@ -30,7 +30,11 @@ docker run --privileged --net host --shm-size=16G -it \
|
||||
&& echo TEST_4 \
|
||||
&& pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
|
||||
&& echo TEST_5 \
|
||||
&& python3 /workspace/vllm/examples/offline_inference/tpu.py" \
|
||||
&& python3 /workspace/vllm/examples/offline_inference/tpu.py \
|
||||
&& echo TEST_6 \
|
||||
&& pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py \
|
||||
&& echo TEST_7 \
|
||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \
|
||||
|
||||
|
||||
# TODO: This test fails because it uses RANDOM_SEED sampling
|
||||
|
@ -135,12 +135,14 @@ steps:
|
||||
- examples/offline_inference/rlhf.py
|
||||
- examples/offline_inference/rlhf_colocate.py
|
||||
- tests/examples/offline_inference/data_parallel.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
commands:
|
||||
# test with tp=2 and external_dp=2
|
||||
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
# test with internal dp
|
||||
- python3 ../examples/offline_inference/data_parallel.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- pytest -v -s distributed/test_utils.py
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
- pytest -v -s distributed/test_pynccl.py
|
||||
@ -287,7 +289,7 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/lora
|
||||
- tests/lora
|
||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py --ignore=lora/test_transfomers_model.py
|
||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py --ignore=lora/test_transfomers_model.py
|
||||
parallelism: 4
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test # 9min
|
||||
@ -514,7 +516,10 @@ steps:
|
||||
- vllm/worker/worker.py
|
||||
- vllm/worker/model_runner.py
|
||||
- entrypoints/llm/test_collective_rpc.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- vllm/v1/engine/
|
||||
commands:
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- VLLM_ENABLE_V1_MULTIPROCESSING=0 pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s ./compile/test_basic_correctness.py
|
||||
- pytest -v -s ./compile/test_wrapper.py
|
||||
@ -592,8 +597,6 @@ steps:
|
||||
# FIXIT: find out which code initialize cuda before running the test
|
||||
# before the fix, we need to use spawn to test it
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
# This test runs llama 13B, so it is required to run on 4 GPUs.
|
||||
- pytest -v -s -x lora/test_long_context.py
|
||||
# There is some Tensor Parallelism related processing logic in LoRA that
|
||||
# requires multi-GPU testing for validation.
|
||||
- pytest -v -s -x lora/test_chatglm3_tp.py
|
||||
|
30
.github/mergify.yml
vendored
30
.github/mergify.yml
vendored
@ -88,6 +88,36 @@ pull_request_rules:
|
||||
add:
|
||||
- v1
|
||||
|
||||
- name: label-tpu
|
||||
description: Automatically apply tpu label
|
||||
# Keep this list in sync with `label-tpu-remove` conditions
|
||||
conditions:
|
||||
- or:
|
||||
- files~=tpu.py
|
||||
- files~=_tpu
|
||||
- files~=tpu_
|
||||
- files~=/tpu/
|
||||
- files~=pallas
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- tpu
|
||||
|
||||
- name: label-tpu-remove
|
||||
description: Automatically remove tpu label
|
||||
# Keep this list in sync with `label-tpu` conditions
|
||||
conditions:
|
||||
- and:
|
||||
- -files~=tpu.py
|
||||
- -files~=_tpu
|
||||
- -files~=tpu_
|
||||
- -files~=/tpu/
|
||||
- -files~=pallas
|
||||
actions:
|
||||
label:
|
||||
remove:
|
||||
- tpu
|
||||
|
||||
- name: ping author on conflicts and add 'needs-rebase' label
|
||||
conditions:
|
||||
- conflict
|
||||
|
@ -461,6 +461,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(FP4_ARCHS)
|
||||
endif()
|
||||
|
||||
#
|
||||
# CUTLASS MoE kernels
|
||||
|
||||
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
|
||||
# on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
|
||||
# to compile MoE kernels that use its output.
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
|
||||
"csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
|
||||
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# Machete kernels
|
||||
|
||||
|
@ -1,37 +1,267 @@
|
||||
FROM mambaorg/micromamba
|
||||
ARG MAMBA_DOCKERFILE_ACTIVATE=1
|
||||
USER root
|
||||
ARG BASE_UBI_IMAGE_TAG=9.5-1741850109
|
||||
|
||||
ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/"
|
||||
###############################################################
|
||||
# base stage with basic dependencies
|
||||
###############################################################
|
||||
|
||||
RUN apt-get update -y && apt-get install -y git wget kmod curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 libssl-dev
|
||||
FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS base-builder
|
||||
|
||||
# Some packages in requirements/cpu are installed here
|
||||
# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba
|
||||
# Currently these may not be available for venv or pip directly
|
||||
RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 rust && micromamba clean --all --yes
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG OPENBLAS_VERSION=0.3.29
|
||||
|
||||
# Set Environment Variables for venv, cargo & openblas
|
||||
ENV VIRTUAL_ENV=/opt/vllm
|
||||
ENV PATH=${VIRTUAL_ENV}/bin:/root/.cargo/bin:$PATH
|
||||
ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig/
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64:/usr/local/lib:/usr/lib64:/usr/lib
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# install gcc-13, python, rust, openblas
|
||||
# Note: A symlink for libatomic.so is created for gcc-13 (linker fails to find libatomic otherwise - reqd. for sentencepiece)
|
||||
# Note: A dummy file 'control' is created in /tmp/ to artificially create dependencies between stages when building stages in parallel
|
||||
# when `--jobs=<N>` is passed with podman build command
|
||||
RUN microdnf install -y openssl-devel dnf \
|
||||
&& dnf install -y https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os/Packages/centos-gpg-keys-9.0-24.el9.noarch.rpm \
|
||||
https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os/Packages/centos-stream-repos-9.0-24.el9.noarch.rpm \
|
||||
https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm \
|
||||
&& dnf config-manager --add-repo https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os \
|
||||
&& dnf config-manager --add-repo https://mirror.stream.centos.org/9-stream/AppStream/`arch`/os \
|
||||
&& dnf config-manager --set-enabled crb \
|
||||
&& dnf install -y \
|
||||
git tar gcc-toolset-13 automake libtool numactl-devel lapack-devel \
|
||||
pkgconfig xsimd zeromq-devel kmod findutils protobuf* \
|
||||
libtiff-devel libjpeg-devel openjpeg2-devel zlib-devel \
|
||||
freetype-devel lcms2-devel libwebp-devel tcl-devel tk-devel \
|
||||
harfbuzz-devel fribidi-devel libraqm-devel libimagequant-devel libxcb-devel \
|
||||
python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip \
|
||||
&& dnf clean all \
|
||||
&& ln -sf /usr/lib64/libatomic.so.1 /usr/lib64/libatomic.so \
|
||||
&& python${PYTHON_VERSION} -m venv ${VIRTUAL_ENV} \
|
||||
&& python -m pip install -U pip uv \
|
||||
&& uv pip install wheel build "setuptools<70" setuptools_scm setuptools_rust meson-python cmake ninja cython scikit_build_core scikit_build \
|
||||
&& curl -sL https://ftp2.osuosl.org/pub/ppc64el/openblas/latest/Openblas_${OPENBLAS_VERSION}_ppc64le.tar.gz | tar xvf - -C /usr/local \
|
||||
&& curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \
|
||||
&& cd /tmp && touch control
|
||||
|
||||
###############################################################
|
||||
# Stage to build torch family
|
||||
###############################################################
|
||||
|
||||
FROM base-builder AS torch-builder
|
||||
|
||||
ARG MAX_JOBS
|
||||
ARG TORCH_VERSION=2.6.0
|
||||
ARG _GLIBCXX_USE_CXX11_ABI=1
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
git clone --recursive https://github.com/pytorch/pytorch.git -b v${TORCH_VERSION} && \
|
||||
cd pytorch && \
|
||||
uv pip install -r requirements.txt && \
|
||||
python setup.py develop && \
|
||||
rm -f dist/torch*+git*whl && \
|
||||
MAX_JOBS=${MAX_JOBS:-$(nproc)} \
|
||||
PYTORCH_BUILD_VERSION=${TORCH_VERSION} PYTORCH_BUILD_NUMBER=1 uv build --wheel --out-dir /torchwheels/
|
||||
|
||||
ARG TORCHVISION_VERSION=0.21.0
|
||||
ARG TORCHVISION_USE_NVJPEG=0
|
||||
ARG TORCHVISION_USE_FFMPEG=0
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
git clone --recursive https://github.com/pytorch/vision.git -b v${TORCHVISION_VERSION} && \
|
||||
cd vision && \
|
||||
MAX_JOBS=${MAX_JOBS:-$(nproc)} \
|
||||
BUILD_VERSION=${TORCHVISION_VERSION} \
|
||||
uv build --wheel --out-dir /torchwheels/ --no-build-isolation
|
||||
|
||||
ARG TORCHAUDIO_VERSION=2.6.0
|
||||
ARG BUILD_SOX=1
|
||||
ARG BUILD_KALDI=1
|
||||
ARG BUILD_RNNT=1
|
||||
ARG USE_FFMPEG=0
|
||||
ARG USE_ROCM=0
|
||||
ARG USE_CUDA=0
|
||||
ARG TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_FFMPEG=1
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
git clone --recursive https://github.com/pytorch/audio.git -b v${TORCHAUDIO_VERSION} && \
|
||||
cd audio && \
|
||||
MAX_JOBS=${MAX_JOBS:-$(nproc)} \
|
||||
BUILD_VERSION=${TORCHAUDIO_VERSION} \
|
||||
uv build --wheel --out-dir /torchwheels/ --no-build-isolation
|
||||
|
||||
###############################################################
|
||||
# Stage to build pyarrow
|
||||
###############################################################
|
||||
|
||||
FROM base-builder AS arrow-builder
|
||||
|
||||
ARG MAX_JOBS
|
||||
ARG PYARROW_PARALLEL
|
||||
ARG PYARROW_VERSION=19.0.1
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
git clone --recursive https://github.com/apache/arrow.git -b apache-arrow-${PYARROW_VERSION} && \
|
||||
cd arrow/cpp && \
|
||||
mkdir build && cd build && \
|
||||
cmake -DCMAKE_BUILD_TYPE=release \
|
||||
-DCMAKE_INSTALL_PREFIX=/usr/local \
|
||||
-DARROW_PYTHON=ON \
|
||||
-DARROW_BUILD_TESTS=OFF \
|
||||
-DARROW_JEMALLOC=ON \
|
||||
-DARROW_BUILD_STATIC="OFF" \
|
||||
-DARROW_PARQUET=ON \
|
||||
.. && \
|
||||
make install -j ${MAX_JOBS:-$(nproc)} && \
|
||||
cd ../../python/ && \
|
||||
uv pip install -v -r requirements-wheel-build.txt && \
|
||||
PYARROW_PARALLEL=${PYARROW_PARALLEL:-$(nproc)} \
|
||||
python setup.py build_ext \
|
||||
--build-type=release --bundle-arrow-cpp \
|
||||
bdist_wheel --dist-dir /arrowwheels/
|
||||
|
||||
###############################################################
|
||||
# Stage to build opencv
|
||||
###############################################################
|
||||
|
||||
FROM base-builder AS cv-builder
|
||||
|
||||
ARG MAX_JOBS
|
||||
ARG OPENCV_VERSION=84
|
||||
ARG ENABLE_HEADLESS=1
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
git clone --recursive https://github.com/opencv/opencv-python.git -b ${OPENCV_VERSION} && \
|
||||
cd opencv-python && \
|
||||
sed -i 's/"setuptools==59.2.0",/"setuptools<70.0",/g' pyproject.toml && \
|
||||
python -m build --wheel --installer=uv --outdir /opencvwheels/
|
||||
|
||||
###############################################################
|
||||
# Stage to build vllm - this stage builds and installs
|
||||
# vllm, tensorizer and vllm-tgis-adapter and builds uv cache
|
||||
# for transitive dependencies - eg. grpcio
|
||||
###############################################################
|
||||
|
||||
FROM base-builder AS vllmcache-builder
|
||||
|
||||
COPY --from=torch-builder /tmp/control /dev/null
|
||||
COPY --from=arrow-builder /tmp/control /dev/null
|
||||
COPY --from=cv-builder /tmp/control /dev/null
|
||||
|
||||
ARG VLLM_TARGET_DEVICE=cpu
|
||||
|
||||
# this step installs vllm and populates uv cache
|
||||
# with all the transitive dependencies
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
|
||||
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
|
||||
--mount=type=bind,from=cv-builder,source=/opencvwheels/,target=/opencvwheels/,ro \
|
||||
--mount=type=bind,src=.,dst=/src/,rw \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl && \
|
||||
sed -i -e 's/.*torch.*//g' /src/pyproject.toml /src/requirements/*.txt && \
|
||||
uv pip install pandas pythran pybind11 && \
|
||||
# sentencepiece.pc is in some pkgconfig inside uv cache
|
||||
export PKG_CONFIG_PATH=$(find / -type d -name "pkgconfig" 2>/dev/null | tr '\n' ':') && \
|
||||
uv pip install -r /src/requirements/common.txt -r /src/requirements/cpu.txt -r /src/requirements/build.txt --no-build-isolation && \
|
||||
cd /src/ && \
|
||||
uv build --wheel --out-dir /vllmwheel/ --no-build-isolation && \
|
||||
uv pip install /vllmwheel/*.whl
|
||||
|
||||
|
||||
###############################################################
|
||||
# Stage to build numactl
|
||||
###############################################################
|
||||
|
||||
FROM base-builder AS numa-builder
|
||||
|
||||
# Note: Building numactl with gcc-11. Compiling with gcc-13 in this builder stage will
|
||||
# trigger recompilation with gcc-11 (and require libtool) in the final stage where we do not have gcc-13
|
||||
ARG MAX_JOBS
|
||||
ARG NUMACTL_VERSION=2.0.19
|
||||
RUN git clone --recursive https://github.com/numactl/numactl.git -b v${NUMACTL_VERSION} \
|
||||
&& cd numactl \
|
||||
&& autoreconf -i && ./configure \
|
||||
&& make -j ${MAX_JOBS:-$(nproc)}
|
||||
|
||||
###############################################################
|
||||
# Stage to build lapack
|
||||
###############################################################
|
||||
|
||||
FROM base-builder AS lapack-builder
|
||||
|
||||
ARG MAX_JOBS
|
||||
ARG LAPACK_VERSION=3.12.1
|
||||
RUN git clone --recursive https://github.com/Reference-LAPACK/lapack.git -b v${LAPACK_VERSION} \
|
||||
&& cd lapack && source /opt/rh/gcc-toolset-13/enable \
|
||||
&& cmake -B build -S . \
|
||||
&& cmake --build build -j ${MAX_JOBS:-$(nproc)}
|
||||
|
||||
|
||||
###############################################################
|
||||
# FINAL VLLM IMAGE STAGE #
|
||||
###############################################################
|
||||
|
||||
FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS vllm-openai
|
||||
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG OPENBLAS_VERSION=0.3.29
|
||||
|
||||
# Set Environment Variables for venv & openblas
|
||||
ENV VIRTUAL_ENV=/opt/vllm
|
||||
ENV PATH=${VIRTUAL_ENV}/bin:$PATH
|
||||
ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig/
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64:/usr/local/lib:/usr/lib64:/usr/lib
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# create artificial dependencies between stages for independent stages to build in parallel
|
||||
COPY --from=torch-builder /tmp/control /dev/null
|
||||
COPY --from=arrow-builder /tmp/control /dev/null
|
||||
COPY --from=cv-builder /tmp/control /dev/null
|
||||
COPY --from=vllmcache-builder /tmp/control /dev/null
|
||||
COPY --from=numa-builder /tmp/control /dev/null
|
||||
COPY --from=lapack-builder /tmp/control /dev/null
|
||||
|
||||
# install gcc-11, python, openblas, numactl, lapack
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=numa-builder,source=/numactl/,target=/numactl/,rw \
|
||||
--mount=type=bind,from=lapack-builder,source=/lapack/,target=/lapack/,rw \
|
||||
rpm -ivh https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm && \
|
||||
microdnf install --nodocs -y \
|
||||
tar findutils openssl \
|
||||
pkgconfig xsimd g++ gcc-fortran libsndfile \
|
||||
libtiff libjpeg openjpeg2 zlib zeromq \
|
||||
freetype lcms2 libwebp tcl tk utf8proc \
|
||||
harfbuzz fribidi libraqm libimagequant libxcb \
|
||||
python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip \
|
||||
&& microdnf clean all \
|
||||
&& python${PYTHON_VERSION} -m venv ${VIRTUAL_ENV} \
|
||||
&& python -m pip install -U pip uv --no-cache \
|
||||
&& curl -sL https://ftp2.osuosl.org/pub/ppc64el/openblas/latest/Openblas_${OPENBLAS_VERSION}_ppc64le.tar.gz | tar xvf - -C /usr/local \
|
||||
&& make -C /numactl install \
|
||||
&& uv pip install cmake \
|
||||
&& cmake --install /lapack/build \
|
||||
&& uv pip uninstall cmake
|
||||
|
||||
# consume previously built wheels (including vllm)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
|
||||
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
|
||||
--mount=type=bind,from=cv-builder,source=/opencvwheels/,target=/opencvwheels/,ro \
|
||||
--mount=type=bind,from=vllmcache-builder,source=/vllmwheel/,target=/vllmwheel/,ro \
|
||||
HOME=/root uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /vllmwheel/*.whl
|
||||
|
||||
COPY ./ /workspace/vllm
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUSTFLAGS='-L /opt/conda/lib' pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \
|
||||
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
|
||||
-r requirements/cpu.txt \
|
||||
xformers uvloop==0.20.0
|
||||
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
VLLM_TARGET_DEVICE=cpu python3 setup.py install
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install -e tests/vllm_test_utils
|
||||
|
||||
WORKDIR /workspace/
|
||||
|
||||
RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
|
||||
|
||||
ENTRYPOINT ["/opt/conda/bin/python3", "-m", "vllm.entrypoints.openai.api_server"]
|
||||
ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"]
|
||||
|
@ -12,7 +12,8 @@ ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH:-${PYTORCH_ROCM_ARCH}}
|
||||
|
||||
# Install some basic utilities
|
||||
RUN apt-get update -q -y && apt-get install -q -y \
|
||||
sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev
|
||||
sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev \
|
||||
apt-transport-https ca-certificates wget curl
|
||||
# Remove sccache
|
||||
RUN python3 -m pip install --upgrade pip && pip install setuptools_scm
|
||||
RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
|
||||
|
340
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
Normal file
340
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
Normal file
@ -0,0 +1,340 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
from benchmark_shapes import WEIGHT_SHAPES_MOE
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
|
||||
fused_experts,
|
||||
fused_topk)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = [
|
||||
"nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite",
|
||||
"ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m"
|
||||
]
|
||||
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
|
||||
DEFAULT_TP_SIZES = [1]
|
||||
|
||||
PER_ACT_TOKEN_OPTS = [False]
|
||||
PER_OUT_CH_OPTS = [False]
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor):
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(
|
||||
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
num_experts: int, topk: int, per_act_token: bool,
|
||||
per_out_ch: bool, mkn: tuple[int, int, int]):
|
||||
label = "Quant Matmul"
|
||||
|
||||
sub_label = (
|
||||
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, "
|
||||
"MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch,
|
||||
mkn))
|
||||
|
||||
print(f"Testing: {sub_label}")
|
||||
|
||||
(m, k, n) = mkn
|
||||
|
||||
dtype = torch.half
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10
|
||||
|
||||
_, a_scale = ops.scaled_fp8_quant(a)
|
||||
|
||||
w1_q = torch.empty((num_experts, 2 * n, k),
|
||||
device="cuda",
|
||||
dtype=torch.float8_e4m3fn)
|
||||
w2_q = torch.empty((num_experts, k, n),
|
||||
device="cuda",
|
||||
dtype=torch.float8_e4m3fn)
|
||||
w1_scale = torch.empty((num_experts, 1, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w2_scale = torch.empty((num_experts, 1, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
|
||||
ab_strides1 = torch.full((num_experts, ),
|
||||
k,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides1 = torch.full((num_experts, ),
|
||||
2 * n,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
ab_strides2 = torch.full((num_experts, ),
|
||||
n,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides2 = torch.full((num_experts, ),
|
||||
k,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
|
||||
for expert in range(num_experts):
|
||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
|
||||
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
|
||||
w1_q_notransp = w1_q.clone()
|
||||
w2_q_notransp = w2_q.clone()
|
||||
w1_q = w1_q.transpose(1, 2)
|
||||
w2_q = w2_q.transpose(1, 2)
|
||||
|
||||
score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
|
||||
|
||||
def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||
a_scale: torch.Tensor, num_repeats: int):
|
||||
for _ in range(num_repeats):
|
||||
fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale)
|
||||
|
||||
def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor,
|
||||
w1: torch.Tensor, w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor, c_strides2: torch.Tensor,
|
||||
num_repeats: int):
|
||||
for _ in range(num_repeats):
|
||||
cutlass_moe_fp8(a,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
ab_strides1,
|
||||
c_strides1,
|
||||
ab_strides2,
|
||||
c_strides2,
|
||||
a1_scale=a_scale)
|
||||
|
||||
def run_cutlass_from_graph(
|
||||
a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
|
||||
w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
return cutlass_moe_fp8(a,
|
||||
w1_q,
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
ab_strides1,
|
||||
c_strides1,
|
||||
ab_strides2,
|
||||
c_strides2,
|
||||
a1_scale=a_scale)
|
||||
|
||||
def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
|
||||
w2: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor, w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor, a_scale: torch.Tensor):
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
return fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale)
|
||||
|
||||
def replay_graph(graph, num_repeats):
|
||||
for _ in range(num_repeats):
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
cutlass_stream = torch.cuda.Stream()
|
||||
cutlass_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
|
||||
run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale,
|
||||
topk_weights, topk_ids, ab_strides1, c_strides1,
|
||||
ab_strides2, c_strides2)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
triton_stream = torch.cuda.Stream()
|
||||
triton_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(triton_graph, stream=triton_stream):
|
||||
run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights,
|
||||
topk_ids, w1_scale, w2_scale, a_scale)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
min_run_time = 5
|
||||
num_warmup = 5
|
||||
num_runs = 25
|
||||
|
||||
globals = {
|
||||
# Baseline params
|
||||
"w1": w1,
|
||||
"w2": w2,
|
||||
"score": score,
|
||||
"topk": topk,
|
||||
"w1_q_notransp": w1_q_notransp,
|
||||
"w2_q_notransp": w2_q_notransp,
|
||||
# Cutlass params
|
||||
"a_scale": a_scale,
|
||||
"w1_q": w1_q,
|
||||
"w2_q": w2_q,
|
||||
"w1_scale": w1_scale,
|
||||
"w2_scale": w2_scale,
|
||||
"ab_strides1": ab_strides1,
|
||||
"c_strides1": c_strides1,
|
||||
"ab_strides2": ab_strides2,
|
||||
"c_strides2": c_strides2,
|
||||
# cuda graph params
|
||||
"cutlass_graph": cutlass_graph,
|
||||
"triton_graph": triton_graph,
|
||||
# Gen params
|
||||
"a": a,
|
||||
"topk_weights": topk_weights,
|
||||
"topk_ids": topk_ids,
|
||||
"num_runs": num_runs,
|
||||
# Kernels
|
||||
"run_triton_moe": run_triton_moe,
|
||||
"run_cutlass_moe": run_cutlass_moe,
|
||||
"replay_graph": replay_graph,
|
||||
}
|
||||
|
||||
# Warmup
|
||||
run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids,
|
||||
w1_scale, w2_scale, a_scale, num_warmup)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="triton_moe",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
|
||||
# Warmup
|
||||
replay_graph(triton_graph, num_warmup)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="replay_graph(triton_graph, num_runs)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="triton_moe_cuda_graphs",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
|
||||
# Warmup
|
||||
run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
|
||||
topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2,
|
||||
num_warmup)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="grouped_gemm_moe",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
|
||||
# Warmup
|
||||
replay_graph(cutlass_graph, num_warmup)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="replay_graph(cutlass_graph, num_runs)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="grouped_gemm_moe_cuda_graphs",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
|
||||
|
||||
def main(args):
|
||||
print("Benchmarking models:")
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
|
||||
results: list[benchmark.Measurement] = []
|
||||
|
||||
for model in args.models:
|
||||
for tp in args.tp_sizes:
|
||||
for layer in WEIGHT_SHAPES_MOE[model]:
|
||||
num_experts = layer[0]
|
||||
topk = layer[1]
|
||||
size_k = layer[2]
|
||||
size_n = layer[3] // tp
|
||||
|
||||
if len(args.limit_k) > 0 and size_k not in args.limit_k:
|
||||
continue
|
||||
|
||||
if len(args.limit_n) > 0 and size_n not in args.limit_n:
|
||||
continue
|
||||
|
||||
for per_act_token in PER_ACT_TOKEN_OPTS:
|
||||
for per_out_ch in PER_OUT_CH_OPTS:
|
||||
for size_m in DEFAULT_BATCH_SIZES:
|
||||
mkn = (size_m, size_k, size_n)
|
||||
bench_run(results, model, num_experts, topk,
|
||||
per_act_token, per_out_ch, mkn)
|
||||
|
||||
compare = benchmark.Compare(results)
|
||||
compare.print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark Marlin across specified models/shapes/batches")
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=DEFAULT_MODELS,
|
||||
choices=WEIGHT_SHAPES_MOE.keys(),
|
||||
)
|
||||
parser.add_argument("--tp-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_TP_SIZES)
|
||||
parser.add_argument("--batch-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_BATCH_SIZES)
|
||||
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-per-act-token",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[])
|
||||
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
@ -7,10 +7,13 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
|
||||
create_kv_caches_with_random)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
NUM_BLOCKS = 128 * 1024
|
||||
PARTITION_SIZE = 512
|
||||
PARTITION_SIZE_ROCM = 256
|
||||
@ -193,6 +196,9 @@ def main(
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logger.warning("This script benchmarks the paged attention kernel. "
|
||||
"By default this is no longer used in vLLM inference.")
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the paged attention kernel.")
|
||||
parser.add_argument("--version",
|
||||
|
@ -75,3 +75,19 @@ WEIGHT_SHAPES = {
|
||||
[7168, 8192],
|
||||
],
|
||||
}
|
||||
|
||||
WEIGHT_SHAPES_MOE = {
|
||||
"nm-testing/Mixtral-8x7B-Instruct-v0.1": [
|
||||
[8, 2, 4096, 28672],
|
||||
[8, 2, 14336, 4096],
|
||||
],
|
||||
"nm-testing/deepseekv2-lite": [
|
||||
[64, 6, 2048, 1408],
|
||||
],
|
||||
"ibm-granite/granite-3.0-1b-a400m": [
|
||||
[32, 8, 1024, 1024],
|
||||
],
|
||||
"ibm-granite/granite-3.0-3b-a800m": [
|
||||
[40, 8, 1024, 1536],
|
||||
],
|
||||
}
|
||||
|
@ -190,6 +190,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/cpu/cache.cpp"
|
||||
"csrc/cpu/utils.cpp"
|
||||
"csrc/cpu/layernorm.cpp"
|
||||
"csrc/cpu/mla_decode.cpp"
|
||||
"csrc/cpu/pos_encoding.cpp"
|
||||
"csrc/cpu/torch_bindings.cpp")
|
||||
|
||||
|
@ -88,6 +88,48 @@ void reshape_and_cache_cpu_impl(
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
void concat_and_cache_mla_cpu_impl(
|
||||
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
|
||||
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
|
||||
scalar_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
|
||||
// + pe_dim)]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int num_tokens, //
|
||||
const int block_stride, //
|
||||
const int entry_stride, //
|
||||
const int kv_c_stride, //
|
||||
const int k_pe_stride, //
|
||||
const int kv_lora_rank, //
|
||||
const int pe_dim, //
|
||||
const int block_size //
|
||||
) {
|
||||
#pragma omp parallel for
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
if (slot_idx < 0) {
|
||||
continue;
|
||||
}
|
||||
const int64_t block_idx = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
|
||||
auto copy = [&](const scalar_t* __restrict__ src,
|
||||
scalar_t* __restrict__ dst, int src_stride, int dst_stride,
|
||||
int size, int offset) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
const int64_t src_idx = token_idx * src_stride + i;
|
||||
const int64_t dst_idx =
|
||||
block_idx * block_stride + block_offset * entry_stride + i + offset;
|
||||
dst[dst_idx] = src[src_idx];
|
||||
}
|
||||
};
|
||||
|
||||
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
|
||||
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
|
||||
}
|
||||
}
|
||||
|
||||
// Note: the key_caches and value_caches vectors are constant but
|
||||
// not the Tensors they contain. The vectors need to be const refs
|
||||
// in order to satisfy pytorch's C++ operator registration code.
|
||||
@ -134,6 +176,38 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
});
|
||||
}
|
||||
|
||||
void concat_and_cache_mla(
|
||||
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
|
||||
torch::Tensor& k_pe, // [num_tokens, pe_dim]
|
||||
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
|
||||
// pe_dim)]
|
||||
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
|
||||
const std::string& kv_cache_dtype, torch::Tensor& scale) {
|
||||
int num_tokens = slot_mapping.size(0);
|
||||
int kv_lora_rank = kv_c.size(1);
|
||||
int pe_dim = k_pe.size(1);
|
||||
int block_size = kv_cache.size(1);
|
||||
|
||||
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||
TORCH_CHECK(kv_cache_dtype != "fp8");
|
||||
|
||||
int kv_c_stride = kv_c.stride(0);
|
||||
int k_pe_stride = k_pe.stride(0);
|
||||
int block_stride = kv_cache.stride(0);
|
||||
int entry_stride = kv_cache.stride(1);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl)
|
||||
concat_and_cache_mla_cpu_impl<scalar_t>(
|
||||
kv_c.data_ptr<scalar_t>(), k_pe.data_ptr<scalar_t>(),
|
||||
kv_cache.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(),
|
||||
num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride,
|
||||
kv_lora_rank, pe_dim, block_size);
|
||||
CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
const torch::Tensor& block_mapping) {
|
||||
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
|
||||
|
@ -130,6 +130,8 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
|
||||
__m512i reg;
|
||||
|
||||
explicit BF16Vec32() : reg(_mm512_setzero_si512()) {}
|
||||
|
||||
explicit BF16Vec32(const void* ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
|
||||
|
||||
explicit BF16Vec32(__m512i data) : reg(data) {}
|
||||
|
393
csrc/cpu/mla_decode.cpp
Normal file
393
csrc/cpu/mla_decode.cpp
Normal file
@ -0,0 +1,393 @@
|
||||
#include "cpu_types.hpp"
|
||||
#include <float.h>
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
struct KernelVecType {
|
||||
using qk_load_vec_type = void;
|
||||
using qk_vec_type = void;
|
||||
using v_load_vec_type = void;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<float> {
|
||||
using qk_load_vec_type = vec_op::FP32Vec16;
|
||||
using qk_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
#if defined(__powerpc64__) || defined(__s390x__)
|
||||
// Power and s390x architecture-specific vector types
|
||||
using qk_load_vec_type = vec_op::FP32Vec16;
|
||||
using qk_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::FP32Vec16;
|
||||
#else
|
||||
// Fallback for other architectures, including x86
|
||||
using qk_load_vec_type = vec_op::FP16Vec16;
|
||||
using qk_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::FP16Vec16;
|
||||
#endif
|
||||
};
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using qk_load_vec_type = vec_op::BF16Vec32;
|
||||
using qk_vec_type = vec_op::BF16Vec32;
|
||||
using v_load_vec_type = vec_op::BF16Vec16;
|
||||
};
|
||||
#elif defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
|
||||
// pass
|
||||
#else
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using qk_load_vec_type = vec_op::BF16Vec16;
|
||||
using qk_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::BF16Vec16;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE, int HEAD_UNROLL,
|
||||
typename qk_vec_type>
|
||||
void mla_decode_block_head(
|
||||
const qk_vec_type* __restrict__ q_vecs, // [HEAD_UNROLL, head_dim]
|
||||
const qk_vec_type* __restrict__ k_vecs, // [block_size, head_dim]
|
||||
const vec_op::FP32Vec16* __restrict v_vecs_f32, // [block_size, v_head_dim]
|
||||
float* __restrict__ acc_out, // [HEAD_UNROLL, v_head_dim]
|
||||
float* __restrict__ acc_lse, // [HEAD_UNROLL]
|
||||
const float scale, const int num_tokens) {
|
||||
using f32_vec_type = vec_op::FP32Vec16;
|
||||
constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
|
||||
constexpr int V_NUM_ELEM = f32_vec_type::VEC_ELEM_NUM;
|
||||
|
||||
float logits[BLOCK_SIZE][HEAD_UNROLL] = {}; // initialize to zeros
|
||||
float max_val[HEAD_UNROLL];
|
||||
std::fill(max_val, max_val + HEAD_UNROLL, -FLT_MAX);
|
||||
|
||||
f32_vec_type acc_vec[BLOCK_SIZE][HEAD_UNROLL];
|
||||
for (int i = 0; i < HEAD_DIM; i += QK_NUM_ELEM) {
|
||||
// load to registers
|
||||
qk_vec_type q_vec[HEAD_UNROLL];
|
||||
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
|
||||
q_vec[unroll] =
|
||||
qk_vec_type{q_vecs[(i + unroll * HEAD_DIM) / QK_NUM_ELEM]};
|
||||
|
||||
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
|
||||
qk_vec_type k_vec(k_vecs[(block_offset * HEAD_DIM + i) / QK_NUM_ELEM]);
|
||||
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
|
||||
vec_op::fma(acc_vec[block_offset][unroll], q_vec[unroll], k_vec);
|
||||
}
|
||||
}
|
||||
|
||||
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
|
||||
const float acc = acc_vec[block_offset][unroll].reduce_sum() * scale;
|
||||
logits[block_offset][unroll] = acc;
|
||||
max_val[unroll] = std::max(max_val[unroll], acc);
|
||||
}
|
||||
}
|
||||
|
||||
float sum_exp[HEAD_UNROLL] = {};
|
||||
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
|
||||
const float val =
|
||||
std::exp(logits[block_offset][unroll] - max_val[unroll]);
|
||||
logits[block_offset][unroll] = val;
|
||||
sum_exp[unroll] += val;
|
||||
}
|
||||
}
|
||||
|
||||
f32_vec_type this_out[V_HEAD_DIM / V_NUM_ELEM][HEAD_UNROLL];
|
||||
|
||||
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
|
||||
// load to registers
|
||||
f32_vec_type scale_[HEAD_UNROLL];
|
||||
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
|
||||
scale_[unroll] =
|
||||
f32_vec_type{logits[block_offset][unroll] / sum_exp[unroll]};
|
||||
|
||||
for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) {
|
||||
f32_vec_type v_vec(
|
||||
v_vecs_f32[(block_offset * HEAD_DIM + i) / V_NUM_ELEM]);
|
||||
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
|
||||
vec_op::fma(this_out[i / V_NUM_ELEM][unroll], v_vec, scale_[unroll]);
|
||||
}
|
||||
}
|
||||
|
||||
// merge attention state
|
||||
// section 2.2 in https://arxiv.org/pdf/2501.01005
|
||||
f32_vec_type prev_scale[HEAD_UNROLL];
|
||||
f32_vec_type curr_scale[HEAD_UNROLL];
|
||||
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
|
||||
const float prev_lse = acc_lse[unroll];
|
||||
const float curr_lse = std::log(sum_exp[unroll]) +
|
||||
max_val[unroll]; // add back max_val to get true lse
|
||||
// softmax trick
|
||||
const float max_lse = std::max(prev_lse, curr_lse);
|
||||
const float prev_sum_exp = std::exp(prev_lse - max_lse);
|
||||
const float curr_sum_exp = std::exp(curr_lse - max_lse);
|
||||
|
||||
const float new_sum_exp = prev_sum_exp + curr_sum_exp;
|
||||
acc_lse[unroll] = std::log(new_sum_exp) + max_lse;
|
||||
|
||||
prev_scale[unroll] = f32_vec_type{prev_sum_exp / new_sum_exp};
|
||||
curr_scale[unroll] = f32_vec_type{curr_sum_exp / new_sum_exp};
|
||||
}
|
||||
|
||||
for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) {
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
|
||||
f32_vec_type o_vec(acc_out + i + V_HEAD_DIM * unroll);
|
||||
o_vec = o_vec * prev_scale[unroll] +
|
||||
this_out[i / V_NUM_ELEM][unroll] * curr_scale[unroll];
|
||||
o_vec.save(acc_out + i + V_HEAD_DIM * unroll);
|
||||
}
|
||||
}
|
||||
|
||||
q_vecs += HEAD_DIM / QK_NUM_ELEM * HEAD_UNROLL;
|
||||
acc_out += V_HEAD_DIM * HEAD_UNROLL;
|
||||
}
|
||||
|
||||
template <typename scalar_t, int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE,
|
||||
typename qk_vec_type>
|
||||
void mla_decode_block(
|
||||
const qk_vec_type* __restrict__ q_vecs, // [num_heads, head_dim]
|
||||
const scalar_t* __restrict__ kv_cache, // [block_size, head_dim]
|
||||
float* __restrict__ acc_out, // [num_heads, v_head_dim]
|
||||
float* __restrict__ acc_lse, // [num_heads]
|
||||
const int num_heads, const float scale, const int num_tokens) {
|
||||
using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
|
||||
static_assert(
|
||||
std::is_same<qk_vec_type,
|
||||
typename KernelVecType<scalar_t>::qk_vec_type>::value);
|
||||
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
|
||||
using f32_vec_type = vec_op::FP32Vec16;
|
||||
static_assert(qk_load_vec_type::VEC_ELEM_NUM == qk_vec_type::VEC_ELEM_NUM);
|
||||
static_assert(v_load_vec_type::VEC_ELEM_NUM == f32_vec_type::VEC_ELEM_NUM);
|
||||
constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
|
||||
constexpr int V_NUM_ELEM = v_load_vec_type::VEC_ELEM_NUM;
|
||||
|
||||
const qk_vec_type* k_vecs;
|
||||
const f32_vec_type* v_vecs_f32;
|
||||
float* kv_cache_f32 = nullptr;
|
||||
|
||||
if constexpr (!std::is_same<scalar_t, float>::value) {
|
||||
// convert KV cache block to FP32 to reuse it across query heads and
|
||||
// attn @ V computation, since FP16/BF16->FP32 is expensive.
|
||||
// TODO: move malloc outside of this fn to reuse across iterations.
|
||||
const int nbytes = BLOCK_SIZE * HEAD_DIM * sizeof(float);
|
||||
kv_cache_f32 = static_cast<float*>(std::aligned_alloc(64, nbytes));
|
||||
|
||||
for (int block_offset = 0; block_offset < num_tokens; ++block_offset)
|
||||
for (int i = 0; i < HEAD_DIM; i += V_NUM_ELEM) {
|
||||
v_load_vec_type kv_load_vec(kv_cache + block_offset * HEAD_DIM + i);
|
||||
f32_vec_type kv_vec_f32(kv_load_vec);
|
||||
kv_vec_f32.save(kv_cache_f32 + block_offset * HEAD_DIM + i);
|
||||
}
|
||||
|
||||
if constexpr (std::is_same<qk_load_vec_type, qk_vec_type>::value) {
|
||||
// for AVX512_BF16, Q @ K.T uses BF16 for K (no conversion)
|
||||
// NOTE: in this case, we only need to convert the V section to FP32.
|
||||
// But for simplicity, we will convert the whole KV block to FP32.
|
||||
k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
|
||||
} else {
|
||||
k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache_f32);
|
||||
}
|
||||
|
||||
// attn @ V always use FP32 for V, since attn is FP32.
|
||||
v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache_f32);
|
||||
|
||||
} else {
|
||||
// KV cache is FP32. don't need to do anything.
|
||||
k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
|
||||
v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache);
|
||||
}
|
||||
|
||||
// compute 2 heads at the same time to improve ILP and
|
||||
// take advantage of register cache for K and V.
|
||||
constexpr int HEAD_UNROLL = 2;
|
||||
for (int iter = 0; iter < num_heads / HEAD_UNROLL; ++iter) {
|
||||
mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, HEAD_UNROLL>(
|
||||
q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);
|
||||
|
||||
q_vecs += HEAD_UNROLL * HEAD_DIM / QK_NUM_ELEM;
|
||||
acc_out += HEAD_UNROLL * V_HEAD_DIM;
|
||||
acc_lse += HEAD_UNROLL;
|
||||
}
|
||||
|
||||
// take care of the remaining heads
|
||||
for (int iter = 0; iter < num_heads % HEAD_UNROLL; ++iter) {
|
||||
mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, 1>(
|
||||
q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);
|
||||
|
||||
q_vecs += HEAD_DIM / QK_NUM_ELEM;
|
||||
acc_out += V_HEAD_DIM;
|
||||
acc_lse += 1;
|
||||
}
|
||||
|
||||
if (kv_cache_f32 != nullptr) {
|
||||
std::free(kv_cache_f32);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename scalar_t, int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE>
|
||||
void mla_decode_kvcache_cpu_impl(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, v_head_dim]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_dim]
|
||||
const scalar_t* __restrict__ kv_cache, // [num_blocks, block_size,
|
||||
// head_dim]
|
||||
const int num_heads, const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq, const int o_stride, const int q_stride,
|
||||
const int kv_stride, const int num_seqs) {
|
||||
using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
|
||||
using qk_vec_type = typename KernelVecType<scalar_t>::qk_vec_type;
|
||||
constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
|
||||
|
||||
// shared across threads
|
||||
const int max_threads = omp_get_max_threads();
|
||||
const int acc_out_nbytes =
|
||||
max_threads * num_heads * V_HEAD_DIM * sizeof(float);
|
||||
float* acc_out = static_cast<float*>(std::aligned_alloc(64, acc_out_nbytes));
|
||||
std::vector<float> acc_lse(max_threads * num_heads);
|
||||
|
||||
// allocate memory to pre-convert query to FP32 later
|
||||
float* q_f32;
|
||||
constexpr bool PRE_CONVERT_QUERY =
|
||||
!std::is_same<scalar_t, float>::value &&
|
||||
std::is_same<qk_vec_type, vec_op::FP32Vec16>::value;
|
||||
if constexpr (PRE_CONVERT_QUERY) {
|
||||
const int q_f32_nbytes = num_heads * HEAD_DIM * sizeof(float);
|
||||
q_f32 = static_cast<float*>(std::aligned_alloc(64, q_f32_nbytes));
|
||||
}
|
||||
|
||||
#pragma omp parallel
|
||||
{
|
||||
const int num_threads = omp_get_num_threads();
|
||||
const int thread_id = omp_get_thread_num();
|
||||
float* __restrict__ acc_out_thread =
|
||||
acc_out + thread_id * num_heads * V_HEAD_DIM;
|
||||
float* __restrict__ acc_lse_thread = acc_lse.data() + thread_id * num_heads;
|
||||
|
||||
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||
// reset accumulator
|
||||
std::fill(acc_out_thread, acc_out_thread + num_heads * V_HEAD_DIM, 0.0f);
|
||||
std::fill(acc_lse_thread, acc_lse_thread + num_heads, -FLT_MAX);
|
||||
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const int last_block_size = seq_len - (block_num - 1) * BLOCK_SIZE;
|
||||
|
||||
const qk_vec_type* q_vecs;
|
||||
if constexpr (PRE_CONVERT_QUERY) {
|
||||
// pre-convert query to FP32 since FP16/BF16->FP32 is slow.
|
||||
#pragma omp for
|
||||
for (int i = 0; i < num_heads * HEAD_DIM; i += QK_NUM_ELEM) {
|
||||
qk_load_vec_type q_load_vec(q + seq_idx * q_stride + i);
|
||||
qk_vec_type q_vec(q_load_vec);
|
||||
q_vec.save(q_f32 + i);
|
||||
}
|
||||
q_vecs = reinterpret_cast<const qk_vec_type*>(q_f32);
|
||||
} else {
|
||||
q_vecs = reinterpret_cast<const qk_vec_type*>(q + seq_idx * q_stride);
|
||||
}
|
||||
|
||||
#pragma omp for
|
||||
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||
const int physical_block_idx =
|
||||
block_tables[seq_idx * max_num_blocks_per_seq + block_idx];
|
||||
const int num_tokens =
|
||||
block_idx < block_num - 1 ? BLOCK_SIZE : last_block_size;
|
||||
|
||||
mla_decode_block<scalar_t, HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE>(
|
||||
q_vecs, kv_cache + physical_block_idx * kv_stride, acc_out_thread,
|
||||
acc_lse_thread, num_heads, scale, num_tokens);
|
||||
}
|
||||
|
||||
// merge attention states across threads
|
||||
// section 2.2 in https://arxiv.org/pdf/2501.01005
|
||||
// each thread is responsible for 1 head
|
||||
#pragma omp for
|
||||
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||
float* acc_lse_head = acc_lse.data() + head_idx;
|
||||
float* acc_out_head = acc_out + head_idx * V_HEAD_DIM;
|
||||
|
||||
float max_val = -FLT_MAX;
|
||||
for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
|
||||
max_val = std::max(max_val, acc_lse_head[thread_id_ * num_heads]);
|
||||
}
|
||||
|
||||
float sum_exp = 0.0f;
|
||||
for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
|
||||
float val = std::exp(acc_lse_head[thread_id_ * num_heads] - max_val);
|
||||
acc_lse_head[thread_id_ * num_heads] = val;
|
||||
sum_exp += val;
|
||||
}
|
||||
|
||||
float inv_sum = 1.0f / sum_exp;
|
||||
float out_head[V_HEAD_DIM] = {};
|
||||
for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
|
||||
float scale_ = acc_lse_head[thread_id_ * num_heads] * inv_sum;
|
||||
for (int i = 0; i < V_HEAD_DIM; ++i) {
|
||||
out_head[i] +=
|
||||
acc_out_head[thread_id_ * num_heads * V_HEAD_DIM + i] * scale_;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < V_HEAD_DIM; ++i) {
|
||||
vec_op::storeFP32(out_head[i], out + seq_idx * o_stride +
|
||||
head_idx * V_HEAD_DIM + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (PRE_CONVERT_QUERY) {
|
||||
std::free(q_f32);
|
||||
}
|
||||
std::free(acc_out);
|
||||
}
|
||||
|
||||
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
|
||||
torch::Tensor& kv_cache, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens) {
|
||||
const int num_seqs = query.size(0);
|
||||
const int num_heads = query.size(1);
|
||||
const int head_dim = query.size(2);
|
||||
const int block_size = kv_cache.size(1);
|
||||
const int v_head_dim = out.size(2);
|
||||
|
||||
const int max_num_blocks_per_seq = block_tables.size(1);
|
||||
const int o_stride = out.stride(0);
|
||||
const int q_stride = query.stride(0);
|
||||
const int kv_stride = kv_cache.stride(0);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
query.scalar_type(), "mla_decode_kvcache_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(mla_decode_kvcache_cpu_impl)
|
||||
if (head_dim == 576 && v_head_dim == 512 && block_size == 16)
|
||||
mla_decode_kvcache_cpu_impl<scalar_t, 576, 512, 16>(
|
||||
out.data_ptr<scalar_t>(), query.data_ptr<scalar_t>(),
|
||||
kv_cache.data_ptr<scalar_t>(), num_heads, scale,
|
||||
block_tables.data_ptr<int>(), seq_lens.data_ptr<int>(),
|
||||
max_num_blocks_per_seq, o_stride, q_stride, kv_stride, num_seqs);
|
||||
else
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size);
|
||||
CPU_KERNEL_GUARD_OUT(mla_decode_kvcache_cpu_impl)
|
||||
});
|
||||
}
|
@ -18,6 +18,10 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
|
||||
const std::optional<torch::Tensor>& azp,
|
||||
const std::optional<torch::Tensor>& bias);
|
||||
|
||||
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
|
||||
torch::Tensor& kv_cache, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@ -150,6 +154,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
" str kv_cache_dtype,"
|
||||
" Tensor k_scale, Tensor v_scale) -> ()");
|
||||
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
|
||||
|
||||
cache_ops.def(
|
||||
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
|
||||
" Tensor! kv_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
" Tensor scale) -> ()");
|
||||
cache_ops.impl("concat_and_cache_mla", torch::kCPU, &concat_and_cache_mla);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
||||
@ -157,4 +169,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
||||
utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cpu), cpu_ops) {
|
||||
cpu_ops.def(
|
||||
"mla_decode_kvcache("
|
||||
" Tensor! out, Tensor query, Tensor kv_cache,"
|
||||
" float scale, Tensor block_tables, Tensor seq_lens) -> ()");
|
||||
cpu_ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
@ -48,4 +48,14 @@ struct enable_sm90_or_later : Kernel {
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm90_only : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
@ -0,0 +1,457 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
//
|
||||
// This file is a modified excerpt of
|
||||
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
|
||||
// from https://github.com/NVIDIA/cutlass v3.5.0
|
||||
// It has been modified to support either row/column or scalar broadcasting
|
||||
// where the tensor being loaded from is always passed in via a device pointer.
|
||||
// This lets one compiled kernel handle all cases of per-tensor or
|
||||
// per-channel/per-token quantization.
|
||||
//
|
||||
// This interface also allows the scales to be passed in as tensors that
|
||||
// consistently reside on the device, which avoids an issue with a previous
|
||||
// implementation where scalars needed to be on the CPU since they
|
||||
// were passed in via float values. This created a potential performance hazard
|
||||
// if scales were initially on the device, and caused torch.compile graphs
|
||||
// breaks when moving scales to the CPU.
|
||||
//
|
||||
#pragma once
|
||||
|
||||
// Turn off clang-format for the entire file to keep it close to upstream
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/barrier.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
|
||||
|
||||
namespace cutlass::epilogue::fusion {
|
||||
|
||||
using namespace cute;
|
||||
using namespace detail;
|
||||
|
||||
// Row vector broadcast
|
||||
template<
|
||||
int Stages,
|
||||
class CtaTileShapeMNK,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_0,_1,_0>,
|
||||
int Alignment = 128 / sizeof_bits_v<Element>
|
||||
>
|
||||
struct Sm90RowOrScalarBroadcastArray {
|
||||
static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
|
||||
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
|
||||
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
|
||||
|
||||
struct SharedStorage {
|
||||
array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
|
||||
};
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_row is a
|
||||
// scalar that must be broadcast, instead of containing a scalar that is
|
||||
// valid if ptr_row is null.
|
||||
struct Arguments {
|
||||
const Element* const* ptr_row_array = nullptr;
|
||||
bool row_broadcast = true;
|
||||
StrideMNL dRow = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90RowOrScalarBroadcastArray() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
|
||||
: params(params)
|
||||
, smem(const_cast<Element*>(shared_storage.smem.data())) { }
|
||||
|
||||
Params params;
|
||||
Element *smem = nullptr;
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_zero() const {
|
||||
return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(
|
||||
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
|
||||
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
|
||||
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
|
||||
CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_,
|
||||
int group, Params const& params_)
|
||||
: tGS_gRow(tGS_gRow_)
|
||||
, tGS_sRow(tGS_sRow_)
|
||||
, tGS_cRow(tGS_cRow_)
|
||||
, tiled_G2S(tiled_g2s_)
|
||||
, tSR_sRow(tSR_sRow_)
|
||||
, tSR_rRow(tSR_rRow_)
|
||||
, tCcRow(tCcRow_)
|
||||
, residue_tCcRow(residue_tCcRow_)
|
||||
, group(group)
|
||||
, params(params_) {}
|
||||
|
||||
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
|
||||
GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
|
||||
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
|
||||
Tiled_G2S tiled_G2S;
|
||||
|
||||
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
ThrResidue residue_tCcRow; // (m, n)
|
||||
ThrNum thr_num;
|
||||
int group;
|
||||
Params const& params;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
if (!params.row_broadcast) {
|
||||
fill(tSR_rRow, *(params.ptr_row_array[group]));
|
||||
return;
|
||||
}
|
||||
|
||||
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||
|
||||
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
|
||||
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
|
||||
continue; // OOB of SMEM,
|
||||
}
|
||||
if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
|
||||
tGS_sRow_flt(i) = tGS_gRow_flt(i);
|
||||
}
|
||||
else {
|
||||
tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
|
||||
}
|
||||
}
|
||||
synchronize();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin_loop(int epi_m, int epi_n) {
|
||||
if (epi_m == 0) { // Assumes M-major subtile loop
|
||||
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
|
||||
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||
copy(tSR_sRow_flt, tSR_rRow_flt);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_row;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
|
||||
}
|
||||
|
||||
return frg_row;
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
using ThreadCount = decltype(size(args.tiled_copy));
|
||||
|
||||
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
|
||||
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
|
||||
//// G2S: Gmem to Smem
|
||||
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
Layout< Shape<_1, ThreadCount>,
|
||||
Stride<_0, _1>>{},
|
||||
Layout<_1>{});
|
||||
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
|
||||
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||
|
||||
//// G2S: Coord
|
||||
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
|
||||
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||
|
||||
//// S2R: Smem to Reg
|
||||
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
|
||||
tGS_gRow,
|
||||
tGS_sRow,
|
||||
tGS_cRow, tiled_g2s,
|
||||
tSR_sRow,
|
||||
tSR_rRow,
|
||||
args.tCcD,
|
||||
args.residue_cD,
|
||||
ThreadCount{},
|
||||
l,
|
||||
params);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Column vector broadcast
|
||||
template<
|
||||
int Stages,
|
||||
class CtaTileShapeMNK,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_1,_0,_0>,
|
||||
int Alignment = 128 / sizeof_bits_v<Element>
|
||||
>
|
||||
struct Sm90ColOrScalarBroadcastArray {
|
||||
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
|
||||
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
|
||||
static_assert(
|
||||
(cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
|
||||
(cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
|
||||
|
||||
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
|
||||
struct SharedStorage { };
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_col is a
|
||||
// scalar that must be broadcast, instead of containing a scalar that is
|
||||
// valid if ptr_col is null.
|
||||
struct Arguments {
|
||||
const Element* const* ptr_col_array = nullptr;
|
||||
bool col_broadcast = true;
|
||||
StrideMNL dCol = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_zero() const {
|
||||
return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ColOrScalarBroadcastArray() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
|
||||
: params(params) { }
|
||||
|
||||
Params params;
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template<class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(
|
||||
GTensor&& tCgCol,
|
||||
RTensor&& tCrCol,
|
||||
CTensor&& tCcCol,
|
||||
ProblemShape problem_shape,
|
||||
int group,
|
||||
Params const& params
|
||||
):
|
||||
tCgCol(cute::forward<GTensor>(tCgCol)),
|
||||
tCrCol(cute::forward<RTensor>(tCrCol)),
|
||||
tCcCol(cute::forward<CTensor>(tCcCol)),
|
||||
m(get<0>(problem_shape)),
|
||||
group(group),
|
||||
params(params) {}
|
||||
|
||||
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
RTensor tCrCol;
|
||||
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
Params const& params;
|
||||
int m;
|
||||
int group;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(pred); ++i) {
|
||||
pred(i) = get<0>(tCcCol(i)) < m;
|
||||
}
|
||||
|
||||
if (!params.col_broadcast) {
|
||||
fill(tCrCol, *(params.ptr_col_array[group]));
|
||||
return;
|
||||
}
|
||||
|
||||
// Filter so we don't issue redundant copies over stride-0 modes
|
||||
// (only works if 0-strides are in same location, which is by construction)
|
||||
copy_if(pred, filter(tCgCol), filter(tCrCol));
|
||||
}
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_col;
|
||||
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
|
||||
}
|
||||
|
||||
return frg_col;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
|
||||
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
|
||||
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
// Generate an identity tensor matching the shape of the global tensor and
|
||||
// partition the same way, this will be used to generate the predicate
|
||||
// tensor for loading
|
||||
Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
|
||||
return ConsumerStoreCallbacks(
|
||||
cute::move(tCgCol),
|
||||
cute::move(tCrCol),
|
||||
cute::move(tCcCol),
|
||||
args.problem_shape_mnkl,
|
||||
l,
|
||||
params
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
|
||||
|
||||
/*
|
||||
This file defines custom epilogues for fusing channel scales, token scales,
|
||||
@ -69,6 +70,16 @@ struct ScaledEpilogueBase {
|
||||
0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
|
||||
128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
|
||||
template <typename T>
|
||||
using ColOrScalarLoadArray =
|
||||
cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray<
|
||||
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrScalarLoadArray =
|
||||
cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
|
||||
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
// This utility function constructs the arguments for the load descriptors
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
@ -96,6 +107,14 @@ struct ScaledEpilogueBase {
|
||||
std::is_same_v<Descriptor, RowLoad<T, true>>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(const T* const* data_ptr, bool do_broadcast) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
static_assert(std::is_same_v<Descriptor, ColOrScalarLoadArray<T>> ||
|
||||
std::is_same_v<Descriptor, RowOrScalarLoadArray<T>>);
|
||||
return Arguments{data_ptr, do_broadcast};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
@ -381,4 +400,51 @@ struct ScaledEpilogueBiasAzpToken
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are pointers
|
||||
to arrays containing different scales used in group gemm. The number of
|
||||
pointers in ScaleA and the number of pointers in ScaleB are equal to the
|
||||
group size.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueArray
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoadArray<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoadArray<float>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
using ScaleAArray = typename SUPER::template ColOrScalarLoadArray<float>;
|
||||
using ScaleBArray = typename SUPER::template RowOrScalarLoadArray<float>;
|
||||
|
||||
static ArgumentType prepare_args(float const* const* a_scales_ptr,
|
||||
float const* const* b_scales_ptr,
|
||||
bool a_col_broadcast, bool b_row_broadcast) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleAArray, float>(
|
||||
a_scales_ptr, a_col_broadcast);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleBArray, float>(
|
||||
b_scales_ptr, b_row_broadcast);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
|
||||
return ArgumentType{a_args, evt0_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace vllm::c3x
|
||||
|
14
csrc/ops.h
14
csrc/ops.h
@ -164,6 +164,7 @@ int64_t ggml_moe_get_block_size(int64_t type);
|
||||
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B, torch::Tensor const& A_sf,
|
||||
@ -175,6 +176,19 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_moe_mm(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
|
||||
|
||||
void get_cutlass_moe_mm_data(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k);
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
|
80
csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh
Normal file
80
csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh
Normal file
@ -0,0 +1,80 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#include "core/scalar_type.hpp"
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
template <typename ElementAB, typename ElementC, typename ElementAccumulator>
|
||||
__global__ void get_group_gemm_starts(
|
||||
int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
|
||||
ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
|
||||
ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int,
|
||||
ElementAB* b_base_as_int, ElementC* out_base_as_int,
|
||||
ElementAccumulator* a_scales_base_as_int,
|
||||
ElementAccumulator* b_scales_base_as_int, int64_t n, int64_t k,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
int expert_id = threadIdx.x;
|
||||
|
||||
int64_t expert_offset = expert_offsets[expert_id];
|
||||
|
||||
a_offsets[expert_id] = a_base_as_int + expert_offset * k;
|
||||
b_offsets[expert_id] = b_base_as_int + expert_id * k * n;
|
||||
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
|
||||
a_scales_offsets[expert_id] =
|
||||
a_scales_base_as_int + (per_act_token ? expert_offset : 0);
|
||||
b_scales_offsets[expert_id] =
|
||||
b_scales_base_as_int + (per_out_ch ? n * expert_id : expert_id);
|
||||
}
|
||||
|
||||
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
|
||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
|
||||
<<<1, num_experts, 0, stream>>>( \
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
|
||||
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
|
||||
static_cast<float**>(a_scales_ptrs.data_ptr()), \
|
||||
static_cast<float**>(b_scales_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()), \
|
||||
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
|
||||
static_cast<float*>(a_scales.data_ptr()), \
|
||||
static_cast<float*>(b_scales.data_ptr()), out_tensors.size(1), \
|
||||
a_tensors.size(1), per_act_token, per_out_ch); \
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
void run_get_group_gemm_starts(
|
||||
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
|
||||
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
|
||||
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
|
||||
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
bool per_act_token = a_scales.numel() != 1;
|
||||
bool per_out_ch = b_scales.numel() != num_experts;
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
|
||||
if (false) {
|
||||
}
|
||||
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
|
||||
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
|
||||
else {
|
||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
160
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu
Normal file
160
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu
Normal file
@ -0,0 +1,160 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "grouped_mm_c3x.cuh"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_default {
|
||||
// M in (16, inf)
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using TileShape = cute::Shape<cute::_64, cute::_256, cute::_128>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_M16 {
|
||||
// M in [1, 16]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using TileShape = cute::Shape<cute::_64, cute::_64, cute::_128>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_4, cute::_1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_K8192 {
|
||||
// K in [8192, inf)
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_N8192 {
|
||||
// N in [8192, inf)
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType>
|
||||
void run_cutlass_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
|
||||
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
||||
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
||||
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
||||
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"A tensors must be of type float8_e4m3fn.");
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"B tensors must be of type float8_e4m3fn.");
|
||||
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM16 = typename sm90_fp8_config_M16<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmDefault = typename sm90_fp8_config_default<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
|
||||
uint32_t const m = a_tensors.size(0);
|
||||
uint32_t const n = out_tensors.size(1);
|
||||
uint32_t const k = a_tensors.size(1);
|
||||
|
||||
if (n >= 8192) {
|
||||
cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides);
|
||||
} else if (k >= 8192) {
|
||||
cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides);
|
||||
} else if (m <= 16) {
|
||||
cutlass_group_gemm_caller<Cutlass3xGemmM16>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides);
|
||||
} else {
|
||||
cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides);
|
||||
}
|
||||
}
|
||||
|
||||
void dispatch_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
|
||||
if (out_tensors.dtype() == torch::kBFloat16) {
|
||||
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides);
|
||||
} else {
|
||||
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void cutlass_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
|
||||
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides);
|
||||
}
|
149
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
Normal file
149
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
Normal file
@ -0,0 +1,149 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "get_group_starts.cuh"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
namespace {
|
||||
|
||||
using ProblemShape =
|
||||
cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;
|
||||
|
||||
using ElementAccumulator = float;
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
|
||||
template <typename ElementAB_, typename ElementC_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_3x_group_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementC = void;
|
||||
using ElementD = ElementC_;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;
|
||||
|
||||
using StrideC =
|
||||
cute::remove_pointer_t<cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;
|
||||
|
||||
static constexpr int AlignmentAB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
|
||||
ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
|
||||
LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp;
|
||||
|
||||
static constexpr size_t CEStorageSize =
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage);
|
||||
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(CEStorageSize)>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB,
|
||||
LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
|
||||
Stages, KernelSchedule>::CollectiveOp;
|
||||
|
||||
using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_group_gemm_caller(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
int k_size = a_tensors.size(1);
|
||||
int n_size = out_tensors.size(1);
|
||||
|
||||
bool per_act_token = a_scales.numel() != 1;
|
||||
bool per_out_ch = b_scales.numel() != num_experts;
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
|
||||
auto options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
|
||||
|
||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
|
||||
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
|
||||
a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors,
|
||||
out_tensors, a_scales, b_scales);
|
||||
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using StrideA = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using StrideB = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using StrideC = typename GemmKernel::InternalStrideC;
|
||||
|
||||
ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
|
||||
static_cast<ProblemShape::UnderlyingProblemShape*>(
|
||||
problem_sizes.data_ptr());
|
||||
ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr};
|
||||
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
static_cast<const ElementAB**>(a_ptrs.data_ptr()),
|
||||
static_cast<StrideA*>(a_strides.data_ptr()),
|
||||
static_cast<const ElementAB**>(b_ptrs.data_ptr()),
|
||||
static_cast<StrideB*>(b_strides.data_ptr())};
|
||||
|
||||
// Currently, we are only able to do broadcast on either all or none a_scales
|
||||
// and on either all or none b_scales
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
Gemm::Epilogue::prepare_args(
|
||||
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
|
||||
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
|
||||
per_act_token, per_out_ch),
|
||||
nullptr, static_cast<StrideC*>(c_strides.data_ptr()),
|
||||
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||
static_cast<StrideC*>(c_strides.data_ptr())};
|
||||
|
||||
typename GemmKernel::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args,
|
||||
epilogue_args};
|
||||
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
} // namespace
|
90
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
Normal file
90
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
Normal file
@ -0,0 +1,90 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
constexpr uint64_t THREADS_PER_EXPERT = 512;
|
||||
|
||||
__global__ void compute_problem_sizes(const int* __restrict__ topk_ids,
|
||||
int32_t* problem_sizes1,
|
||||
int32_t* problem_sizes2,
|
||||
int32_t* atomic_buffer,
|
||||
const int topk_length, const int n,
|
||||
const int k) {
|
||||
int expert_id = blockIdx.x;
|
||||
|
||||
int occurrences = 0;
|
||||
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
|
||||
occurrences += (topk_ids[i] == expert_id);
|
||||
}
|
||||
atomicAdd(&atomic_buffer[expert_id], occurrences);
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
int final_occurrences = atomic_buffer[expert_id];
|
||||
problem_sizes1[expert_id * 3] = final_occurrences;
|
||||
problem_sizes1[expert_id * 3 + 1] = 2 * n;
|
||||
problem_sizes1[expert_id * 3 + 2] = k;
|
||||
problem_sizes2[expert_id * 3] = final_occurrences;
|
||||
problem_sizes2[expert_id * 3 + 1] = k;
|
||||
problem_sizes2[expert_id * 3 + 2] = n;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void compute_expert_offsets(
|
||||
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
|
||||
int32_t* atomic_buffer, const int num_experts) {
|
||||
int32_t tot_offset = 0;
|
||||
expert_offsets[0] = 0;
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
atomic_buffer[i] = tot_offset;
|
||||
tot_offset += problem_sizes1[i * 3];
|
||||
expert_offsets[i + 1] = tot_offset;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
|
||||
int32_t* input_permutation,
|
||||
int32_t* output_permutation,
|
||||
int32_t* atomic_buffer, const int topk_length,
|
||||
const int topk) {
|
||||
int expert_id = blockIdx.x;
|
||||
|
||||
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
|
||||
if (topk_ids[i] == expert_id) {
|
||||
int start = atomicAdd(&atomic_buffer[expert_id], 1);
|
||||
input_permutation[start] = i / topk;
|
||||
output_permutation[i] = start;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
|
||||
auto options_int32 =
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
|
||||
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
|
||||
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
|
||||
compute_expert_offsets<<<1, 1, 0, stream>>>(
|
||||
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
|
||||
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(input_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(output_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
|
||||
topk_ids.size(1));
|
||||
}
|
@ -29,6 +29,20 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
|
||||
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k);
|
||||
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
@ -102,6 +116,19 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
|
||||
// CUTLASS groped FP8 kernels need at least CUDA 12.3
|
||||
// and SM90 (Hopper)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
if (cuda_device_capability == 90) {
|
||||
return CUDA_VERSION >= 12030;
|
||||
}
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
@ -168,6 +195,46 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
version_num);
|
||||
}
|
||||
|
||||
void cutlass_moe_mm(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides);
|
||||
return;
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
|
||||
". Required capability: 90");
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_data(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k) {
|
||||
// This function currently gets compiled only if we have a valid cutlass moe
|
||||
// mm to run it for.
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
|
||||
problem_sizes2, input_permutation,
|
||||
output_permutation, num_experts, n, k);
|
||||
return;
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
|
||||
"CUDA device capability: ",
|
||||
version_num, ". Required capability: 90");
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
|
@ -375,25 +375,25 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input
|
||||
int64_t ggml_moe_get_block_size(int64_t type) {
|
||||
switch (type) {
|
||||
case 2:
|
||||
return MMQ_X_Q4_0;
|
||||
return MOE_X_Q4_0;
|
||||
case 3:
|
||||
return MMQ_X_Q4_1;
|
||||
return MOE_X_Q4_1;
|
||||
case 6:
|
||||
return MMQ_X_Q5_0;
|
||||
return MOE_X_Q5_0;
|
||||
case 7:
|
||||
return MMQ_X_Q5_1;
|
||||
return MOE_X_Q5_1;
|
||||
case 8:
|
||||
return MMQ_X_Q8_0;
|
||||
return MOE_X_Q8_0;
|
||||
case 10:
|
||||
return MMQ_X_Q2_K;
|
||||
return MOE_X_Q2_K;
|
||||
case 11:
|
||||
return MMQ_X_Q3_K;
|
||||
return MOE_X_Q3_K;
|
||||
case 12:
|
||||
return MMQ_X_Q4_K;
|
||||
return MOE_X_Q4_K;
|
||||
case 13:
|
||||
return MMQ_X_Q5_K;
|
||||
return MOE_X_Q5_K;
|
||||
case 14:
|
||||
return MMQ_X_Q6_K;
|
||||
return MOE_X_Q6_K;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
@ -129,12 +129,12 @@ static __device__ __forceinline__ void moe_q(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q4_0 64
|
||||
#define MMQ_Y_Q4_0 128
|
||||
#define MOE_X_Q4_0 64
|
||||
#define MOE_Y_Q4_0 128
|
||||
#define NWARPS_Q4_0 8
|
||||
#else
|
||||
#define MMQ_X_Q4_0 4
|
||||
#define MMQ_Y_Q4_0 32
|
||||
#define MOE_X_Q4_0 4
|
||||
#define MOE_Y_Q4_0 32
|
||||
#define NWARPS_Q4_0 4
|
||||
#endif
|
||||
|
||||
@ -149,8 +149,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_0, 2)
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q4_0;
|
||||
const int mmq_y = MMQ_Y_Q4_0;
|
||||
const int mmq_x = MOE_X_Q4_0;
|
||||
const int mmq_y = MOE_Y_Q4_0;
|
||||
const int nwarps = NWARPS_Q4_0;
|
||||
|
||||
moe_q<scalar_t, QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps,
|
||||
@ -167,8 +167,8 @@ static void ggml_moe_q4_0_q8_1_cuda(
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
int mmq_x = MMQ_X_Q4_0;
|
||||
int mmq_y = MMQ_Y_Q4_0;
|
||||
int mmq_x = MOE_X_Q4_0;
|
||||
int mmq_y = MOE_Y_Q4_0;
|
||||
int nwarps = NWARPS_Q4_0;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
@ -190,12 +190,12 @@ static void ggml_moe_q4_0_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q4_1 64
|
||||
#define MMQ_Y_Q4_1 128
|
||||
#define MOE_X_Q4_1 64
|
||||
#define MOE_Y_Q4_1 128
|
||||
#define NWARPS_Q4_1 8
|
||||
#else
|
||||
#define MMQ_X_Q4_1 4
|
||||
#define MMQ_Y_Q4_1 32
|
||||
#define MOE_X_Q4_1 4
|
||||
#define MOE_Y_Q4_1 32
|
||||
#define NWARPS_Q4_1 4
|
||||
#endif
|
||||
|
||||
@ -210,8 +210,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_1, 2)
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q4_1;
|
||||
const int mmq_y = MMQ_Y_Q4_1;
|
||||
const int mmq_x = MOE_X_Q4_1;
|
||||
const int mmq_y = MOE_Y_Q4_1;
|
||||
const int nwarps = NWARPS_Q4_1;
|
||||
|
||||
moe_q<scalar_t, QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps,
|
||||
@ -228,8 +228,8 @@ static void ggml_moe_q4_1_q8_1_cuda(
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
int mmq_x = MMQ_X_Q4_1;
|
||||
int mmq_y = MMQ_Y_Q4_1;
|
||||
int mmq_x = MOE_X_Q4_1;
|
||||
int mmq_y = MOE_Y_Q4_1;
|
||||
int nwarps = NWARPS_Q4_1;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
@ -251,12 +251,12 @@ static void ggml_moe_q4_1_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q5_0 64
|
||||
#define MMQ_Y_Q5_0 128
|
||||
#define MOE_X_Q5_0 64
|
||||
#define MOE_Y_Q5_0 128
|
||||
#define NWARPS_Q5_0 8
|
||||
#else
|
||||
#define MMQ_X_Q5_0 4
|
||||
#define MMQ_Y_Q5_0 32
|
||||
#define MOE_X_Q5_0 4
|
||||
#define MOE_Y_Q5_0 32
|
||||
#define NWARPS_Q5_0 4
|
||||
#endif
|
||||
|
||||
@ -271,8 +271,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_0, 2)
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q5_0;
|
||||
const int mmq_y = MMQ_Y_Q5_0;
|
||||
const int mmq_x = MOE_X_Q5_0;
|
||||
const int mmq_y = MOE_Y_Q5_0;
|
||||
const int nwarps = NWARPS_Q5_0;
|
||||
|
||||
moe_q<scalar_t, QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps,
|
||||
@ -289,8 +289,8 @@ static void ggml_moe_q5_0_q8_1_cuda(
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q5_0;
|
||||
const int mmq_y = MMQ_Y_Q5_0;
|
||||
const int mmq_x = MOE_X_Q5_0;
|
||||
const int mmq_y = MOE_Y_Q5_0;
|
||||
const int nwarps = NWARPS_Q5_0;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
@ -312,12 +312,12 @@ static void ggml_moe_q5_0_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q5_1 64
|
||||
#define MMQ_Y_Q5_1 128
|
||||
#define MOE_X_Q5_1 64
|
||||
#define MOE_Y_Q5_1 128
|
||||
#define NWARPS_Q5_1 8
|
||||
#else
|
||||
#define MMQ_X_Q5_1 4
|
||||
#define MMQ_Y_Q5_1 32
|
||||
#define MOE_X_Q5_1 4
|
||||
#define MOE_Y_Q5_1 32
|
||||
#define NWARPS_Q5_1 4
|
||||
#endif
|
||||
|
||||
@ -332,8 +332,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_1, 2)
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q5_1;
|
||||
const int mmq_y = MMQ_Y_Q5_1;
|
||||
const int mmq_x = MOE_X_Q5_1;
|
||||
const int mmq_y = MOE_Y_Q5_1;
|
||||
const int nwarps = NWARPS_Q5_1;
|
||||
|
||||
moe_q<scalar_t, QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps,
|
||||
@ -350,8 +350,8 @@ static void ggml_moe_q5_1_q8_1_cuda(
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q5_1;
|
||||
const int mmq_y = MMQ_Y_Q5_1;
|
||||
const int mmq_x = MOE_X_Q5_1;
|
||||
const int mmq_y = MOE_Y_Q5_1;
|
||||
const int nwarps = NWARPS_Q5_1;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
@ -373,12 +373,12 @@ static void ggml_moe_q5_1_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q8_0 64
|
||||
#define MMQ_Y_Q8_0 128
|
||||
#define MOE_X_Q8_0 64
|
||||
#define MOE_Y_Q8_0 128
|
||||
#define NWARPS_Q8_0 8
|
||||
#else
|
||||
#define MMQ_X_Q8_0 4
|
||||
#define MMQ_Y_Q8_0 32
|
||||
#define MOE_X_Q8_0 4
|
||||
#define MOE_Y_Q8_0 32
|
||||
#define NWARPS_Q8_0 4
|
||||
#endif
|
||||
|
||||
@ -393,8 +393,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q8_0, 2)
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q8_0;
|
||||
const int mmq_y = MMQ_Y_Q8_0;
|
||||
const int mmq_x = MOE_X_Q8_0;
|
||||
const int mmq_y = MOE_Y_Q8_0;
|
||||
const int nwarps = NWARPS_Q8_0;
|
||||
|
||||
moe_q<scalar_t, QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps,
|
||||
@ -411,8 +411,8 @@ static void ggml_moe_q8_0_q8_1_cuda(
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q8_0;
|
||||
const int mmq_y = MMQ_Y_Q8_0;
|
||||
const int mmq_x = MOE_X_Q8_0;
|
||||
const int mmq_y = MOE_Y_Q8_0;
|
||||
const int nwarps = NWARPS_Q8_0;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
@ -434,12 +434,12 @@ static void ggml_moe_q8_0_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q2_K 64
|
||||
#define MMQ_Y_Q2_K 128
|
||||
#define MOE_X_Q2_K 64
|
||||
#define MOE_Y_Q2_K 128
|
||||
#define NWARPS_Q2_K 8
|
||||
#else
|
||||
#define MMQ_X_Q2_K 4
|
||||
#define MMQ_Y_Q2_K 32
|
||||
#define MOE_X_Q2_K 4
|
||||
#define MOE_Y_Q2_K 32
|
||||
#define NWARPS_Q2_K 4
|
||||
#endif
|
||||
|
||||
@ -454,8 +454,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q2_K, 2)
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q2_K;
|
||||
const int mmq_y = MMQ_Y_Q2_K;
|
||||
const int mmq_x = MOE_X_Q2_K;
|
||||
const int mmq_y = MOE_Y_Q2_K;
|
||||
const int nwarps = NWARPS_Q2_K;
|
||||
|
||||
moe_q<scalar_t, QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps,
|
||||
@ -472,8 +472,8 @@ static void ggml_moe_q2_K_q8_1_cuda(
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q2_K;
|
||||
const int mmq_y = MMQ_Y_Q2_K;
|
||||
const int mmq_x = MOE_X_Q2_K;
|
||||
const int mmq_y = MOE_Y_Q2_K;
|
||||
const int nwarps = NWARPS_Q2_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
@ -495,12 +495,12 @@ static void ggml_moe_q2_K_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q3_K 64
|
||||
#define MMQ_Y_Q3_K 128
|
||||
#define MOE_X_Q3_K 64
|
||||
#define MOE_Y_Q3_K 128
|
||||
#define NWARPS_Q3_K 8
|
||||
#else
|
||||
#define MMQ_X_Q3_K 4
|
||||
#define MMQ_Y_Q3_K 32
|
||||
#define MOE_X_Q3_K 4
|
||||
#define MOE_Y_Q3_K 32
|
||||
#define NWARPS_Q3_K 4
|
||||
#endif
|
||||
|
||||
@ -516,8 +516,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q3_K, 2)
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
|
||||
const int mmq_x = MMQ_X_Q3_K;
|
||||
const int mmq_y = MMQ_Y_Q3_K;
|
||||
const int mmq_x = MOE_X_Q3_K;
|
||||
const int mmq_y = MOE_Y_Q3_K;
|
||||
const int nwarps = NWARPS_Q3_K;
|
||||
|
||||
moe_q<scalar_t, QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps,
|
||||
@ -533,8 +533,8 @@ static void ggml_moe_q3_K_q8_1_cuda(
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q3_K;
|
||||
const int mmq_y = MMQ_Y_Q3_K;
|
||||
const int mmq_x = MOE_X_Q3_K;
|
||||
const int mmq_y = MOE_Y_Q3_K;
|
||||
const int nwarps = NWARPS_Q3_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
@ -556,12 +556,12 @@ static void ggml_moe_q3_K_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q4_K 64
|
||||
#define MMQ_Y_Q4_K 128
|
||||
#define MOE_X_Q4_K 64
|
||||
#define MOE_Y_Q4_K 128
|
||||
#define NWARPS_Q4_K 8
|
||||
#else
|
||||
#define MMQ_X_Q4_K 4
|
||||
#define MMQ_Y_Q4_K 32
|
||||
#define MOE_X_Q4_K 4
|
||||
#define MOE_Y_Q4_K 32
|
||||
#define NWARPS_Q4_K 4
|
||||
#endif
|
||||
|
||||
@ -576,8 +576,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_K, 2)
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q4_K;
|
||||
const int mmq_y = MMQ_Y_Q4_K;
|
||||
const int mmq_x = MOE_X_Q4_K;
|
||||
const int mmq_y = MOE_Y_Q4_K;
|
||||
const int nwarps = NWARPS_Q4_K;
|
||||
|
||||
moe_q<scalar_t, QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps,
|
||||
@ -594,8 +594,8 @@ static void ggml_moe_q4_K_q8_1_cuda(
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q4_K;
|
||||
const int mmq_y = MMQ_Y_Q4_K;
|
||||
const int mmq_x = MOE_X_Q4_K;
|
||||
const int mmq_y = MOE_Y_Q4_K;
|
||||
const int nwarps = NWARPS_Q4_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
@ -617,12 +617,12 @@ static void ggml_moe_q4_K_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q5_K 64
|
||||
#define MMQ_Y_Q5_K 128
|
||||
#define MOE_X_Q5_K 64
|
||||
#define MOE_Y_Q5_K 128
|
||||
#define NWARPS_Q5_K 8
|
||||
#else
|
||||
#define MMQ_X_Q5_K 4
|
||||
#define MMQ_Y_Q5_K 32
|
||||
#define MOE_X_Q5_K 4
|
||||
#define MOE_Y_Q5_K 32
|
||||
#define NWARPS_Q5_K 4
|
||||
#endif
|
||||
|
||||
@ -637,8 +637,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_K, 2)
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q5_K;
|
||||
const int mmq_y = MMQ_Y_Q5_K;
|
||||
const int mmq_x = MOE_X_Q5_K;
|
||||
const int mmq_y = MOE_Y_Q5_K;
|
||||
const int nwarps = NWARPS_Q5_K;
|
||||
|
||||
moe_q<scalar_t, QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps,
|
||||
@ -655,8 +655,8 @@ static void ggml_moe_q5_K_q8_1_cuda(
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q5_K;
|
||||
const int mmq_y = MMQ_Y_Q5_K;
|
||||
const int mmq_x = MOE_X_Q5_K;
|
||||
const int mmq_y = MOE_Y_Q5_K;
|
||||
const int nwarps = NWARPS_Q5_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
@ -678,12 +678,12 @@ static void ggml_moe_q5_K_q8_1_cuda(
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q6_K 64
|
||||
#define MMQ_Y_Q6_K 128
|
||||
#define MOE_X_Q6_K 64
|
||||
#define MOE_Y_Q6_K 128
|
||||
#define NWARPS_Q6_K 8
|
||||
#else
|
||||
#define MMQ_X_Q6_K 4
|
||||
#define MMQ_Y_Q6_K 32
|
||||
#define MOE_X_Q6_K 4
|
||||
#define MOE_Y_Q6_K 32
|
||||
#define NWARPS_Q6_K 4
|
||||
#endif
|
||||
|
||||
@ -698,8 +698,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q6_K, 2)
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q6_K;
|
||||
const int mmq_y = MMQ_Y_Q6_K;
|
||||
const int mmq_x = MOE_X_Q6_K;
|
||||
const int mmq_y = MOE_Y_Q6_K;
|
||||
const int nwarps = NWARPS_Q6_K;
|
||||
|
||||
moe_q<scalar_t, QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps,
|
||||
@ -716,8 +716,8 @@ static void ggml_moe_q6_K_q8_1_cuda(
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q6_K;
|
||||
const int mmq_y = MMQ_Y_Q6_K;
|
||||
const int mmq_x = MOE_X_Q6_K;
|
||||
const int mmq_y = MOE_Y_Q6_K;
|
||||
const int nwarps = NWARPS_Q6_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
|
@ -14,7 +14,7 @@ __global__ void awq_marlin_repack_kernel(
|
||||
int n_tiles = size_n / tile_n_size;
|
||||
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
|
||||
|
||||
int start_k_tile = blockIdx.x * block_k_tiles;
|
||||
auto start_k_tile = blockIdx.x * block_k_tiles;
|
||||
if (start_k_tile >= k_tiles) {
|
||||
return;
|
||||
}
|
||||
@ -51,8 +51,8 @@ __global__ void awq_marlin_repack_kernel(
|
||||
int4* sh_ptr = sh + stage_size * pipe;
|
||||
|
||||
if (threadIdx.x < stage_size) {
|
||||
int k_id = threadIdx.x / stage_n_threads;
|
||||
int n_id = threadIdx.x % stage_n_threads;
|
||||
auto k_id = threadIdx.x / stage_n_threads;
|
||||
auto n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
int first_k = k_tile_id * tile_k_size;
|
||||
|
||||
@ -70,8 +70,8 @@ __global__ void awq_marlin_repack_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
int warp_id = threadIdx.x / 32;
|
||||
int th_id = threadIdx.x % 32;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
auto th_id = threadIdx.x % 32;
|
||||
|
||||
if (warp_id >= 4) {
|
||||
return;
|
||||
@ -265,4 +265,4 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
|
||||
m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
|
||||
}
|
||||
}
|
||||
|
@ -460,7 +460,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
int const* __restrict__ perm_int_ptr,
|
||||
int4* __restrict__ out_int4_ptr, int size_m,
|
||||
int size_k, int lda, int block_rows) {
|
||||
int start_row = block_rows * blockIdx.x;
|
||||
auto start_row = block_rows * blockIdx.x;
|
||||
int finish_row = start_row + block_rows;
|
||||
if (finish_row > size_m) {
|
||||
finish_row = size_m;
|
||||
@ -484,7 +484,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
int base_k = 0;
|
||||
|
||||
for (int i = 0; i < iters; i++) {
|
||||
int cur_k = base_k + threadIdx.x;
|
||||
auto cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
@ -494,7 +494,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
|
||||
if (rest) {
|
||||
if (threadIdx.x < rest) {
|
||||
int cur_k = base_k + threadIdx.x;
|
||||
auto cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
@ -723,8 +723,8 @@ __global__ void Marlin(
|
||||
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
|
||||
b_gl_rd += b_sh_stride * slice_col;
|
||||
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
||||
int b_sh_wr = threadIdx.x * b_thread_vecs;
|
||||
int b_sh_rd = threadIdx.x * b_thread_vecs;
|
||||
auto b_sh_wr = threadIdx.x * b_thread_vecs;
|
||||
auto b_sh_rd = threadIdx.x * b_thread_vecs;
|
||||
|
||||
// For act_order
|
||||
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
|
||||
@ -743,7 +743,7 @@ __global__ void Marlin(
|
||||
s_sh_stride * slice_col + threadIdx.x;
|
||||
}
|
||||
}
|
||||
int s_sh_wr = threadIdx.x;
|
||||
auto s_sh_wr = threadIdx.x;
|
||||
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
|
||||
|
||||
// Zero-points
|
||||
@ -756,7 +756,7 @@ __global__ void Marlin(
|
||||
zp_sh_stride * slice_col + threadIdx.x;
|
||||
}
|
||||
}
|
||||
int zp_sh_wr = threadIdx.x;
|
||||
auto zp_sh_wr = threadIdx.x;
|
||||
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
|
||||
|
||||
// We use a different scale layout for grouped and column-wise quantization as
|
||||
@ -1047,7 +1047,7 @@ __global__ void Marlin(
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
||||
} else {
|
||||
int warp_id = threadIdx.x / 32;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
int n_warps = thread_n_blocks / 4;
|
||||
|
||||
int warp_row = warp_id / n_warps;
|
||||
@ -1085,7 +1085,7 @@ __global__ void Marlin(
|
||||
|
||||
// Determine "position" inside the thread-block (based on warp and
|
||||
// thread-id)
|
||||
int warp_id = threadIdx.x / 32;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
int n_warps =
|
||||
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
|
||||
|
||||
@ -1094,7 +1094,7 @@ __global__ void Marlin(
|
||||
|
||||
cur_k += warp_row * 16;
|
||||
|
||||
int th_id = threadIdx.x % 32;
|
||||
auto th_id = threadIdx.x % 32;
|
||||
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
|
||||
|
||||
int s_col_shift =
|
||||
@ -1159,7 +1159,7 @@ __global__ void Marlin(
|
||||
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
||||
}
|
||||
} else {
|
||||
int warp_id = threadIdx.x / 32;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
int n_warps = thread_n_blocks / 4;
|
||||
|
||||
int warp_row = warp_id / n_warps;
|
||||
@ -1197,7 +1197,7 @@ __global__ void Marlin(
|
||||
(pipe / (group_blocks / thread_k_blocks)));
|
||||
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
|
||||
} else {
|
||||
int warp_id = threadIdx.x / 32;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
int n_warps = thread_n_blocks / 4;
|
||||
|
||||
int warp_row = warp_id / n_warps;
|
||||
@ -1323,7 +1323,7 @@ __global__ void Marlin(
|
||||
auto thread_block_reduce = [&]() {
|
||||
constexpr int red_off = threads / b_sh_stride_threads / 2;
|
||||
if (red_off >= 1) {
|
||||
int red_idx = threadIdx.x / b_sh_stride_threads;
|
||||
auto red_idx = threadIdx.x / b_sh_stride_threads;
|
||||
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
|
||||
constexpr int red_sh_delta = b_sh_stride_threads;
|
||||
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
|
||||
@ -1390,7 +1390,7 @@ __global__ void Marlin(
|
||||
4 * (threadIdx.x / 32) + threadIdx.x % 4;
|
||||
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
||||
constexpr int c_sh_wr_delta = active_threads;
|
||||
int c_sh_wr = threadIdx.x;
|
||||
auto c_sh_wr = threadIdx.x;
|
||||
|
||||
int row = (threadIdx.x % 32) / 4;
|
||||
|
||||
|
@ -15,7 +15,7 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
int n_tiles = size_n / tile_n_size;
|
||||
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
|
||||
|
||||
int start_k_tile = blockIdx.x * block_k_tiles;
|
||||
auto start_k_tile = blockIdx.x * block_k_tiles;
|
||||
if (start_k_tile >= k_tiles) {
|
||||
return;
|
||||
}
|
||||
@ -71,8 +71,8 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
|
||||
if constexpr (has_perm) {
|
||||
if (threadIdx.x < stage_size) {
|
||||
int k_id = threadIdx.x / stage_n_threads;
|
||||
int n_id = threadIdx.x % stage_n_threads;
|
||||
auto k_id = threadIdx.x / stage_n_threads;
|
||||
auto n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
uint32_t const* sh_perm_int_ptr =
|
||||
reinterpret_cast<uint32_t const*>(sh_perm_ptr);
|
||||
@ -88,8 +88,8 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
|
||||
} else {
|
||||
if (threadIdx.x < stage_size) {
|
||||
int k_id = threadIdx.x / stage_n_threads;
|
||||
int n_id = threadIdx.x % stage_n_threads;
|
||||
auto k_id = threadIdx.x / stage_n_threads;
|
||||
auto n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
int first_k = k_tile_id * tile_k_size;
|
||||
int first_k_packed = first_k / pack_factor;
|
||||
@ -109,8 +109,8 @@ __global__ void gptq_marlin_repack_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
int warp_id = threadIdx.x / 32;
|
||||
int th_id = threadIdx.x % 32;
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
auto th_id = threadIdx.x % 32;
|
||||
|
||||
if (warp_id >= 4) {
|
||||
return;
|
||||
@ -339,4 +339,4 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
|
||||
m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
|
||||
}
|
||||
}
|
||||
|
@ -277,12 +277,12 @@ __global__ void Marlin(
|
||||
b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
|
||||
b_gl_rd += b_sh_stride * slice_col;
|
||||
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
||||
int b_sh_wr = threadIdx.x;
|
||||
int b_sh_rd = threadIdx.x;
|
||||
auto b_sh_wr = threadIdx.x;
|
||||
auto b_sh_rd = threadIdx.x;
|
||||
|
||||
int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
||||
s_sh_stride * slice_col + threadIdx.x;
|
||||
int s_sh_wr = threadIdx.x;
|
||||
auto s_sh_wr = threadIdx.x;
|
||||
int s_sh_rd;
|
||||
// We use a different scale layout for grouped and column-wise quantization as
|
||||
// we scale a `half2` tile in column-major layout in the former and in
|
||||
@ -455,7 +455,7 @@ __global__ void Marlin(
|
||||
auto thread_block_reduce = [&]() {
|
||||
constexpr int red_off = threads / b_sh_stride / 2;
|
||||
if (red_off >= 1) {
|
||||
int red_idx = threadIdx.x / b_sh_stride;
|
||||
auto red_idx = threadIdx.x / b_sh_stride;
|
||||
constexpr int red_sh_stride = b_sh_stride * 4 * 2;
|
||||
constexpr int red_sh_delta = b_sh_stride;
|
||||
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
|
||||
@ -522,7 +522,7 @@ __global__ void Marlin(
|
||||
4 * (threadIdx.x / 32) + threadIdx.x % 4;
|
||||
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
||||
constexpr int c_sh_wr_delta = active_threads;
|
||||
int c_sh_wr = threadIdx.x;
|
||||
auto c_sh_wr = threadIdx.x;
|
||||
|
||||
int row = (threadIdx.x % 32) / 4;
|
||||
|
||||
|
@ -353,10 +353,10 @@ __global__ void Marlin(
|
||||
b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
|
||||
b_gl_rd += b_sh_stride * slice_col;
|
||||
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
||||
int b_sh_wr = threadIdx.x;
|
||||
int b_sh_rd = threadIdx.x;
|
||||
auto b_sh_wr = threadIdx.x;
|
||||
auto b_sh_rd = threadIdx.x;
|
||||
|
||||
int s_tok_gl_rd = threadIdx.x;
|
||||
auto s_tok_gl_rd = threadIdx.x;
|
||||
// NOTE(HandH1998): activation scale s_tok need shuffle to [0, 8, 1, 9, 2, 10,
|
||||
// 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] for example, 0, 8 row scales serve for
|
||||
// thread 0, 1, 2, 3. For more details, refer to mma operand A layout as
|
||||
@ -368,8 +368,8 @@ __global__ void Marlin(
|
||||
int s_tok_sh_rd = (threadIdx.x % 32) / 4;
|
||||
bool s_tok_sh_wr_pred = threadIdx.x < prob_m;
|
||||
|
||||
int s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
|
||||
int s_ch_sh_wr = threadIdx.x;
|
||||
auto s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
|
||||
auto s_ch_sh_wr = threadIdx.x;
|
||||
int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
2 * ((threadIdx.x % 32) % 4);
|
||||
bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride;
|
||||
@ -558,7 +558,7 @@ __global__ void Marlin(
|
||||
auto thread_block_reduce = [&]() {
|
||||
constexpr int red_off = threads / b_sh_stride / 2;
|
||||
if (red_off >= 1) {
|
||||
int red_idx = threadIdx.x / b_sh_stride;
|
||||
auto red_idx = threadIdx.x / b_sh_stride;
|
||||
constexpr int red_sh_stride = b_sh_stride * 4 * 2;
|
||||
constexpr int red_sh_delta = b_sh_stride;
|
||||
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
|
||||
@ -628,7 +628,7 @@ __global__ void Marlin(
|
||||
8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2;
|
||||
c_gl_wr += (4 * thread_n_blocks) * slice_col;
|
||||
constexpr int c_sh_wr_delta = active_threads * 2;
|
||||
int c_sh_wr = 2 * threadIdx.x;
|
||||
auto c_sh_wr = 2 * threadIdx.x;
|
||||
|
||||
int row = (threadIdx.x % 32) / 4;
|
||||
|
||||
|
@ -273,15 +273,15 @@ __global__ void Marlin_24(
|
||||
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
|
||||
b_gl_rd += b_sh_stride * slice_col;
|
||||
b_gl_rd += b_gl_rd_delta_o * slice_row;
|
||||
int b_sh_wr = threadIdx.x * b_thread_vecs;
|
||||
int b_sh_rd = threadIdx.x * b_thread_vecs;
|
||||
auto b_sh_wr = threadIdx.x * b_thread_vecs;
|
||||
auto b_sh_rd = threadIdx.x * b_thread_vecs;
|
||||
|
||||
int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) +
|
||||
(threadIdx.x % (m_sh_stride));
|
||||
m_gl_rd += (m_sh_stride)*slice_col;
|
||||
m_gl_rd += m_gl_rd_delta_o * slice_row;
|
||||
int m_sh_wr = threadIdx.x;
|
||||
int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16;
|
||||
auto m_sh_wr = threadIdx.x;
|
||||
auto m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16;
|
||||
|
||||
int s_gl_rd;
|
||||
if constexpr (group_blocks == -1) {
|
||||
@ -291,7 +291,7 @@ __global__ void Marlin_24(
|
||||
s_sh_stride * slice_col + threadIdx.x;
|
||||
}
|
||||
|
||||
int s_sh_wr = threadIdx.x;
|
||||
auto s_sh_wr = threadIdx.x;
|
||||
int s_sh_rd;
|
||||
// We use a different scale layout for grouped and column-wise quantization as
|
||||
// we scale a `half2` tile in column-major layout in the former and in
|
||||
@ -516,7 +516,7 @@ __global__ void Marlin_24(
|
||||
auto thread_block_reduce = [&]() {
|
||||
constexpr int red_off = threads / b_sh_stride_threads / 2;
|
||||
if (red_off >= 1) {
|
||||
int red_idx = threadIdx.x / b_sh_stride_threads;
|
||||
auto red_idx = threadIdx.x / b_sh_stride_threads;
|
||||
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
|
||||
constexpr int red_sh_delta = b_sh_stride_threads;
|
||||
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
|
||||
@ -583,7 +583,7 @@ __global__ void Marlin_24(
|
||||
8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4;
|
||||
c_gl_wr += (2 * thread_n_blocks) * slice_col;
|
||||
constexpr int c_sh_wr_delta = active_threads;
|
||||
int c_sh_wr = threadIdx.x;
|
||||
auto c_sh_wr = threadIdx.x;
|
||||
|
||||
int col = 2 * ((threadIdx.x % 32) % 4);
|
||||
|
||||
|
@ -284,18 +284,18 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
|
||||
// clang-format on
|
||||
constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int warpid = threadIdx.x / WARP_SIZE;
|
||||
const int laneid = threadIdx.x % WARP_SIZE;
|
||||
const auto warpid = threadIdx.x / WARP_SIZE;
|
||||
const auto laneid = threadIdx.x % WARP_SIZE;
|
||||
const int lane4id = laneid % 4;
|
||||
const int lane16id = laneid % 16;
|
||||
const int rowid = laneid / 16;
|
||||
|
||||
const int seq_idx = blockIdx.x;
|
||||
const int partition_idx = blockIdx.y;
|
||||
const auto seq_idx = blockIdx.x;
|
||||
const auto partition_idx = blockIdx.y;
|
||||
|
||||
constexpr int T_PAR_SIZE = 256; // token partition size set to 256
|
||||
|
||||
const int max_num_partitions = gridDim.y;
|
||||
const auto max_num_partitions = gridDim.y;
|
||||
|
||||
const int context_len = context_lens[seq_idx];
|
||||
|
||||
@ -346,9 +346,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
// can be interpreted as B8x16 for 8 bit types
|
||||
_B16x8 Klocal[TLOOP][QKHELOOP];
|
||||
|
||||
const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
|
||||
const int wg_start_kv_head_idx = blockIdx.z;
|
||||
const int total_num_heads = gridDim.z * GQA_RATIO;
|
||||
const auto wg_start_head_idx = blockIdx.z * GQA_RATIO;
|
||||
const auto wg_start_kv_head_idx = blockIdx.z;
|
||||
const auto total_num_heads = gridDim.z * GQA_RATIO;
|
||||
|
||||
// for QK mfma, tokens in multiples of TOKENS_PER_WARP are spread across warps
|
||||
// each mfma takes QH16xT16x16HE across warp
|
||||
@ -789,14 +789,14 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
|
||||
// clang-format on
|
||||
constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int warpid = threadIdx.x / WARP_SIZE;
|
||||
const int laneid = threadIdx.x % WARP_SIZE;
|
||||
const auto warpid = threadIdx.x / WARP_SIZE;
|
||||
const auto laneid = threadIdx.x % WARP_SIZE;
|
||||
const int lane4id = laneid % 4;
|
||||
|
||||
const int seq_idx = blockIdx.x;
|
||||
const int partition_idx = blockIdx.y;
|
||||
const int partition_size = blockDim.x;
|
||||
const int max_num_partitions = gridDim.y;
|
||||
const auto seq_idx = blockIdx.x;
|
||||
const auto partition_idx = blockIdx.y;
|
||||
const auto partition_size = blockDim.x;
|
||||
const auto max_num_partitions = gridDim.y;
|
||||
|
||||
const int context_len = context_lens[seq_idx];
|
||||
const int partition_start_token_idx = partition_idx * partition_size;
|
||||
@ -838,8 +838,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
qk_max[h] = -FLT_MAX;
|
||||
}
|
||||
|
||||
const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
|
||||
const int wg_start_kv_head_idx = blockIdx.z;
|
||||
const auto wg_start_head_idx = blockIdx.z * GQA_RATIO;
|
||||
const auto wg_start_kv_head_idx = blockIdx.z;
|
||||
|
||||
const int warp_start_token_idx =
|
||||
partition_start_token_idx + warpid * WARP_SIZE;
|
||||
@ -857,7 +857,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
|
||||
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||
// token id within partition
|
||||
const int local_token_idx = threadIdx.x;
|
||||
const auto local_token_idx = threadIdx.x;
|
||||
// token id within sequence
|
||||
const int global_token_idx = partition_start_token_idx + local_token_idx;
|
||||
|
||||
@ -1126,7 +1126,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const int num_heads = gridDim.z * GQA_RATIO;
|
||||
const auto num_heads = gridDim.z * GQA_RATIO;
|
||||
float* max_logits_ptr =
|
||||
max_logits + seq_idx * num_heads * max_num_partitions + partition_idx;
|
||||
float* exp_sums_ptr =
|
||||
@ -1268,14 +1268,14 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
// max_num_partitions, head_size]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int max_num_partitions) {
|
||||
const int num_heads = gridDim.x;
|
||||
const int head_idx = blockIdx.x;
|
||||
const int seq_idx = blockIdx.y;
|
||||
const auto num_heads = gridDim.x;
|
||||
const auto head_idx = blockIdx.x;
|
||||
const auto seq_idx = blockIdx.y;
|
||||
const int context_len = context_lens[seq_idx];
|
||||
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
|
||||
[[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int warpid = threadIdx.x / WARP_SIZE;
|
||||
[[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE;
|
||||
const auto warpid = threadIdx.x / WARP_SIZE;
|
||||
[[maybe_unused]] const auto laneid = threadIdx.x % WARP_SIZE;
|
||||
|
||||
__shared__ float shared_global_exp_sum;
|
||||
// max num partitions supported is warp_size * NPAR_LOOPS
|
||||
@ -1294,7 +1294,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NPAR_LOOPS; i++) {
|
||||
const int partition_no = i * WARP_SIZE + threadIdx.x;
|
||||
const auto partition_no = i * WARP_SIZE + threadIdx.x;
|
||||
valid_partition[i] =
|
||||
(partition_no < num_partitions) ? partition_no : last_valid_partition;
|
||||
}
|
||||
@ -1324,7 +1324,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NPAR_LOOPS; i++) {
|
||||
const int partition_no = i * WARP_SIZE + threadIdx.x;
|
||||
const auto partition_no = i * WARP_SIZE + threadIdx.x;
|
||||
rescaled_exp_sum[i] *= (partition_no < num_partitions)
|
||||
? expf(reg_max_logit[i] - max_logit)
|
||||
: 0.0f;
|
||||
@ -1336,7 +1336,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NPAR_LOOPS; i++) {
|
||||
const int partition_no = i * WARP_SIZE + threadIdx.x;
|
||||
const auto partition_no = i * WARP_SIZE + threadIdx.x;
|
||||
shared_exp_sums[partition_no] = rescaled_exp_sum[i];
|
||||
}
|
||||
|
||||
|
@ -365,6 +365,35 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
|
||||
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
|
||||
|
||||
// Check if cutlass grouped gemm is supported for CUDA devices of the given
|
||||
// capability
|
||||
ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool");
|
||||
ops.impl("cutlass_group_gemm_supported", &cutlass_group_gemm_supported);
|
||||
|
||||
// CUTLASS w8a8 grouped GEMM
|
||||
ops.def(
|
||||
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
|
||||
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
|
||||
" Tensor problem_sizes, Tensor a_strides, "
|
||||
" Tensor b_strides, Tensor c_strides) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm);
|
||||
|
||||
// A function that computes data required to run fused MoE with w8a8 grouped
|
||||
// GEMM. It takes topk_ids as an input, and computes expert_offsets
|
||||
// (token start indices of each expert). In addition to this, it computes
|
||||
// problem sizes for each expert's multiplication used by the two mms called
|
||||
// from fused MoE operation, and arrays with permutations required to shuffle
|
||||
// and de-shuffle the input/output of the fused operation.
|
||||
ops.def(
|
||||
"get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
|
||||
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
|
||||
" Tensor! input_permutation, "
|
||||
" Tensor! output_permutation, int num_experts, "
|
||||
" int n, int k) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
|
||||
|
||||
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
|
||||
|
@ -10,8 +10,8 @@ document.addEventListener("DOMContentLoaded", function () {
|
||||
script.setAttribute("runllm-keyboard-shortcut", "Mod+j"); // cmd-j or ctrl-j to open the widget.
|
||||
script.setAttribute("runllm-name", "vLLM");
|
||||
script.setAttribute("runllm-position", "BOTTOM_RIGHT");
|
||||
script.setAttribute("runllm-position-y", "20%");
|
||||
script.setAttribute("runllm-position-x", "3%");
|
||||
script.setAttribute("runllm-position-y", "120px");
|
||||
script.setAttribute("runllm-position-x", "20px");
|
||||
script.setAttribute("runllm-assistant-id", "207");
|
||||
|
||||
script.async = true;
|
||||
|
@ -103,6 +103,11 @@ myst_url_schemes = {
|
||||
"title": "Pull Request #{{path}}",
|
||||
"classes": ["github"],
|
||||
},
|
||||
"gh-project": {
|
||||
"url": "https://github.com/vllm-project/projects/{{path}}",
|
||||
"title": "Project #{{path}}",
|
||||
"classes": ["github"],
|
||||
},
|
||||
"gh-dir": {
|
||||
"url": "https://github.com/vllm-project/vllm/tree/main/{{path}}",
|
||||
"title": "{{path}}",
|
||||
|
@ -11,6 +11,15 @@ We also believe in the power of community support; thus, answering queries, offe
|
||||
|
||||
Finally, one of the most impactful ways to support us is by raising awareness about vLLM. Talk about it in your blog posts and highlight how it's driving your incredible projects. Express your support on social media if you're using vLLM, or simply offer your appreciation by starring our repository!
|
||||
|
||||
## Job Board
|
||||
|
||||
Unsure on where to start? Check out the following links for tasks to work on:
|
||||
|
||||
- [Good first issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22good%20first%20issue%22)
|
||||
- [Selected onboarding tasks](gh-project:6)
|
||||
- [New model requests](https://github.com/vllm-project/vllm/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22new%20model%22)
|
||||
- [Models with multi-modal capabilities](gh-project:10)
|
||||
|
||||
## License
|
||||
|
||||
See <gh-file:LICENSE>.
|
||||
|
@ -34,11 +34,11 @@ If you need to use those dependencies (having accepted the license terms),
|
||||
create a custom Dockerfile on top of the base image with an extra layer that installs them:
|
||||
|
||||
```Dockerfile
|
||||
FROM vllm/vllm-openai:v0.8.0
|
||||
FROM vllm/vllm-openai:v0.8.2
|
||||
|
||||
# e.g. install the `audio` and `video` optional dependencies
|
||||
# NOTE: Make sure the version of vLLM matches the base image!
|
||||
RUN uv pip install vllm[audio,video]==0.8.0
|
||||
RUN uv pip install --system vllm[audio,video]==0.8.2
|
||||
```
|
||||
|
||||
:::
|
||||
@ -52,7 +52,7 @@ with an extra layer that installs their code from source:
|
||||
```Dockerfile
|
||||
FROM vllm/vllm-openai:latest
|
||||
|
||||
RUN uv pip install git+https://github.com/huggingface/transformers.git
|
||||
RUN uv pip install --system git+https://github.com/huggingface/transformers.git
|
||||
```
|
||||
|
||||
:::
|
||||
|
@ -15,12 +15,13 @@ Block 3: |<------------------ prefix -------------------->| |<--- block tokens -
|
||||
In the example above, the KV cache in the first block can be uniquely identified with the token “A gentle breeze stirred”. The third block can be uniquely identified with the tokens in the block “laughed in the distance”, along with the prefix tokens “A gentle breeze stirred the leaves as children”. Therefore, we can build the block hash of `hash(tuple[components])`, where components are:
|
||||
|
||||
* Parent hash value: The hash value of the parent hash block.
|
||||
* Block tokens: A tuple of tokens in this block. The reason to include the exact tokens is to reduce potential hash value collision.
|
||||
* Block tokens: A tuple of tokens in this block. The reason to include the exact tokens is to reduce potential hash value collision.
|
||||
* Extra hashes: Other values required to make this block unique, such as LoRA IDs and multi-modality input hashes (see the example below).
|
||||
|
||||
Note 1: We only cache full blocks.
|
||||
> **Note 1:** We only cache full blocks.
|
||||
|
||||
Note 2: The above hash key structure is not 100% collision free. Theoretically it’s still possible for the different prefix tokens to have the same hash value, but this should be nearly impossible to happen. Of course, contributions are welcome if you have an awesome idea to eliminate collusion entirely.
|
||||
> **Note 2:** The above hash key structure is not 100% collision free. Theoretically it’s still possible for the different prefix tokens to have the same hash value. To avoid any hash collisions **in a multi-tenant setup, we advise to use SHA256** as hash function instead of the default builtin hash.
|
||||
SHA256 is supported since vLLM v0.8.3 and must be enabled with a command line argument. It comes with a performance impact of about 100-200ns per token (~6ms for 50k tokens of context).
|
||||
|
||||
**A hashing example with multi-modality inputs**
|
||||
In this example, we illustrate how prefix caching works with multi-modality inputs (e.g., images). Assuming we have a request with the following messages:
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
vLLM offers support for reasoning models like [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1), which are designed to generate outputs containing both reasoning steps and final conclusions.
|
||||
|
||||
Reasoning models return a additional `reasoning_content` field in their outputs, which contains the reasoning steps that led to the final conclusion. This field is not present in the outputs of other models.
|
||||
Reasoning models return an additional `reasoning_content` field in their outputs, which contains the reasoning steps that led to the final conclusion. This field is not present in the outputs of other models.
|
||||
|
||||
## Supported Models
|
||||
|
||||
@ -14,6 +14,9 @@ vLLM currently supports the following reasoning models:
|
||||
|--------------|-------------|------------------|-------------|
|
||||
| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | ❌ |
|
||||
| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ |
|
||||
| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ |
|
||||
|
||||
- IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`.
|
||||
|
||||
## Quickstart
|
||||
|
||||
@ -43,6 +46,7 @@ model = models.data[0].id
|
||||
|
||||
# Round 1
|
||||
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
|
||||
# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
|
||||
response = client.chat.completions.create(model=model, messages=messages)
|
||||
|
||||
reasoning_content = response.choices[0].message.reasoning_content
|
||||
@ -97,6 +101,7 @@ models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
|
||||
# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
|
||||
stream = client.chat.completions.create(model=model,
|
||||
messages=messages,
|
||||
stream=True)
|
||||
|
@ -47,9 +47,9 @@ This living user guide outlines a few known **important changes and limitations*
|
||||
| **Logprobs Calculation** | <nobr>🟢 Functional</nobr> |
|
||||
| **LoRA** | <nobr>🟢 Functional ([PR #13096](https://github.com/vllm-project/vllm/pull/13096))</nobr>|
|
||||
| **Multimodal Models** | <nobr>🟢 Functional</nobr> |
|
||||
| **FP8 KV Cache** | <nobr>🟢 Functional on Hopper devices ([PR #15191](https://github.com/vllm-project/vllm/pull/15191))</nobr>|
|
||||
| **Spec Decode** | <nobr>🚧 WIP ([PR #13933](https://github.com/vllm-project/vllm/pull/13933))</nobr>|
|
||||
| **Prompt Logprobs with Prefix Caching** | <nobr>🟡 Planned ([RFC #13414](https://github.com/vllm-project/vllm/issues/13414))</nobr>|
|
||||
| **FP8 KV Cache** | <nobr>🟡 Planned</nobr> |
|
||||
| **Structured Output Alternative Backends** | <nobr>🟡 Planned</nobr> |
|
||||
| **Embedding Models** | <nobr>🟡 Planned ([RFC #12249](https://github.com/vllm-project/vllm/issues/12249))</nobr> |
|
||||
| **Mamba Models** | <nobr>🟡 Planned</nobr> |
|
||||
@ -129,9 +129,10 @@ in progress.
|
||||
- **Spec Decode**: Currently, only ngram-based spec decode is supported in V1. There
|
||||
will be follow-up work to support other types of spec decode (e.g., see [PR #13933](https://github.com/vllm-project/vllm/pull/13933)). We will prioritize the support for Eagle, MTP compared to draft model based spec decode.
|
||||
|
||||
#### Features to Be Supported
|
||||
- **Multimodal Models**: V1 is almost fully compatible with V0 except that interleaved modality input is not supported yet.
|
||||
See [here](https://github.com/orgs/vllm-project/projects/8) for the status of upcoming features and optimizations.
|
||||
|
||||
- **FP8 KV Cache**: While vLLM V1 introduces new FP8 kernels for model weight quantization, support for an FP8 key–value cache is not yet available. Users must continue using FP16 (or other supported precisions) for the KV cache.
|
||||
#### Features to Be Supported
|
||||
|
||||
- **Structured Output Alternative Backends**: Structured output alternative backends (outlines, guidance) support is planned. V1 currently
|
||||
supports only the `xgrammar:no_fallback` mode, meaning that it will error out if the output schema is unsupported by xgrammar.
|
||||
|
@ -43,7 +43,7 @@ vLLM is flexible and easy to use with:
|
||||
- Tensor parallelism and pipeline parallelism support for distributed inference
|
||||
- Streaming outputs
|
||||
- OpenAI-compatible API server
|
||||
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, PowerPC CPUs, TPU, and AWS Trainium and Inferentia Accelerators.
|
||||
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators.
|
||||
- Prefix caching support
|
||||
- Multi-lora support
|
||||
|
||||
|
@ -57,10 +57,10 @@ llm = LLM(model=..., task="generate") # Name or path of your model
|
||||
llm.apply_model(lambda model: print(type(model)))
|
||||
```
|
||||
|
||||
If it is `TransformersModel` then it means it's based on Transformers!
|
||||
If it is `TransformersForCausalLM` then it means it's based on Transformers!
|
||||
|
||||
:::{tip}
|
||||
You can force the use of `TransformersModel` by setting `model_impl="transformers"` for <project:#offline-inference> or `--model-impl transformers` for the <project:#openai-compatible-server>.
|
||||
You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for <project:#offline-inference> or `--model-impl transformers` for the <project:#openai-compatible-server>.
|
||||
:::
|
||||
|
||||
:::{note}
|
||||
@ -119,7 +119,7 @@ Here is what happens in the background:
|
||||
|
||||
1. The config is loaded
|
||||
2. `MyModel` Python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`.
|
||||
3. The `TransformersModel` backend is used. See <gh-file:vllm/model_executor/models/transformers.py>, which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`.
|
||||
3. The `TransformersForCausalLM` backend is used. See <gh-file:vllm/model_executor/models/transformers.py>, which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`.
|
||||
|
||||
To make your model compatible with tensor parallel, it needs:
|
||||
|
||||
@ -836,14 +836,14 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
* `openbmb/MiniCPM-o-2_6`, etc.
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
*
|
||||
* ✅︎
|
||||
- * `MiniCPMV`
|
||||
* MiniCPM-V
|
||||
* T + I<sup>E+</sup> + V<sup>E+</sup>
|
||||
* `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc.
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
*
|
||||
* ✅︎
|
||||
- * `MllamaForConditionalGeneration`
|
||||
* Llama 3.2
|
||||
* T + I<sup>+</sup>
|
||||
@ -853,7 +853,7 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
*
|
||||
- * `MolmoForCausalLM`
|
||||
* Molmo
|
||||
* T + I
|
||||
* T + I<sup>+</sup>
|
||||
* `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc.
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
|
@ -18,7 +18,10 @@ llm = LLM(model="facebook/opt-125m")
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:\n" + "-" * 60)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
print(f"Prompt: {prompt!r}")
|
||||
print(f"Output: {generated_text!r}")
|
||||
print("-" * 60)
|
@ -27,12 +27,13 @@ def main(args: dict):
|
||||
sampling_params.top_k = top_k
|
||||
|
||||
def print_outputs(outputs):
|
||||
print("\nGenerated Outputs:\n" + "-" * 80)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}")
|
||||
print(f"Prompt: {prompt!r}\n")
|
||||
print(f"Generated text: {generated_text!r}")
|
||||
print("-" * 80)
|
||||
print("-" * 80)
|
||||
|
||||
print("=" * 80)
|
||||
|
||||
|
@ -23,12 +23,14 @@ def main(args: Namespace):
|
||||
outputs = model.classify(prompts)
|
||||
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:\n" + "-" * 60)
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
probs = output.outputs.probs
|
||||
probs_trimmed = ((str(probs[:16])[:-1] +
|
||||
", ...]") if len(probs) > 16 else probs)
|
||||
print(f"Prompt: {prompt!r} | "
|
||||
print(f"Prompt: {prompt!r} \n"
|
||||
f"Class Probabilities: {probs_trimmed} (size={len(probs)})")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -23,12 +23,14 @@ def main(args: Namespace):
|
||||
outputs = model.embed(prompts)
|
||||
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:\n" + "-" * 60)
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
embeds = output.outputs.embedding
|
||||
embeds_trimmed = ((str(embeds[:16])[:-1] +
|
||||
", ...]") if len(embeds) > 16 else embeds)
|
||||
print(f"Prompt: {prompt!r} | "
|
||||
print(f"Prompt: {prompt!r} \n"
|
||||
f"Embeddings: {embeds_trimmed} (size={len(embeds)})")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -22,9 +22,11 @@ def main(args: Namespace):
|
||||
outputs = model.score(text_1, texts_2)
|
||||
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:\n" + "-" * 60)
|
||||
for text_2, output in zip(texts_2, outputs):
|
||||
score = output.outputs.score
|
||||
print(f"Pair: {[text_1, text_2]!r} | Score: {score}")
|
||||
print(f"Pair: {[text_1, text_2]!r} \nScore: {score}")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,26 +1,49 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# usage:
|
||||
# VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
|
||||
# we need to have a launcher to create multiple data parallel
|
||||
# ranks. And each rank will create a vLLM instance to process its own prompts.
|
||||
"""
|
||||
Usage:
|
||||
Single node:
|
||||
python examples/offline_inference/data_parallel.py \
|
||||
--model="ibm-research/PowerMoE-3b" \
|
||||
--dp-size=2 \
|
||||
--tp-size=2
|
||||
|
||||
Multi-node:
|
||||
Node 0 (assume the node has ip of 10.99.48.128):
|
||||
python examples/offline_inference/data_parallel.py \
|
||||
--model="ibm-research/PowerMoE-3b" \
|
||||
--dp-size=2 \
|
||||
--tp-size=2 \
|
||||
--node-size=2 \
|
||||
--node-rank=0 \
|
||||
--master-addr=10.99.48.128 \
|
||||
--master-port=13345
|
||||
Node 1:
|
||||
python examples/offline_inference/data_parallel.py \
|
||||
--model="ibm-research/PowerMoE-3b" \
|
||||
--dp-size=2 \
|
||||
--tp-size=2 \
|
||||
--node-size=2 \
|
||||
--node-rank=1 \
|
||||
--master-addr=10.99.48.128 \
|
||||
--master-port=13345
|
||||
"""
|
||||
import os
|
||||
from time import sleep
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.utils import get_open_port
|
||||
|
||||
GPUs_per_dp_rank = 2
|
||||
DP_size = 2
|
||||
|
||||
|
||||
def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
|
||||
os.environ["VLLM_DP_RANK"] = str(dp_rank)
|
||||
def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
|
||||
dp_master_port, GPUs_per_dp_rank):
|
||||
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
|
||||
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
|
||||
os.environ["VLLM_DP_SIZE"] = str(dp_size)
|
||||
os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
|
||||
os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
|
||||
# set devices for each dp_rank
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
|
||||
str(i) for i in range(dp_rank * GPUs_per_dp_rank, (dp_rank + 1) *
|
||||
GPUs_per_dp_rank))
|
||||
|
||||
# CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
|
||||
# engine processes.
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
@ -28,20 +51,20 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
] * 100
|
||||
|
||||
# with DP, each rank should process different prompts.
|
||||
# usually all the DP ranks process a full dataset,
|
||||
# and each rank processes a different part of the dataset.
|
||||
promts_per_rank = len(prompts) // dp_size
|
||||
start = dp_rank * promts_per_rank
|
||||
start = global_dp_rank * promts_per_rank
|
||||
end = start + promts_per_rank
|
||||
prompts = prompts[start:end]
|
||||
if len(prompts) == 0:
|
||||
# if any rank has no prompts to process,
|
||||
# we need to set a placeholder prompt
|
||||
prompts = ["Placeholder"]
|
||||
print(f"DP rank {dp_rank} needs to process {len(prompts)} prompts")
|
||||
print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts")
|
||||
|
||||
# Create a sampling params object.
|
||||
# since we are doing data parallel, every rank can have different
|
||||
@ -49,37 +72,96 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
|
||||
# ranks for demonstration.
|
||||
sampling_params = SamplingParams(temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=16 * (dp_rank + 1))
|
||||
max_tokens=[16, 20][global_dp_rank % 2])
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(model="ibm-research/PowerMoE-3b",
|
||||
llm = LLM(model=model,
|
||||
tensor_parallel_size=GPUs_per_dp_rank,
|
||||
enforce_eager=True,
|
||||
enable_expert_parallel=True)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
for i, output in enumerate(outputs):
|
||||
if i >= 5:
|
||||
# print only 5 outputs
|
||||
break
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"DP rank {dp_rank}, Prompt: {prompt!r}, "
|
||||
print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
|
||||
f"Generated text: {generated_text!r}")
|
||||
|
||||
# Give engines time to pause their processing loops before exiting.
|
||||
sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Data Parallel Inference")
|
||||
parser.add_argument("--model",
|
||||
type=str,
|
||||
default="ibm-research/PowerMoE-3b",
|
||||
help="Model name or path")
|
||||
parser.add_argument("--dp-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Data parallel size")
|
||||
parser.add_argument("--tp-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Tensor parallel size")
|
||||
parser.add_argument("--node-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Total number of nodes")
|
||||
parser.add_argument("--node-rank",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Rank of the current node")
|
||||
parser.add_argument("--master-addr",
|
||||
type=str,
|
||||
default="",
|
||||
help="Master node IP address")
|
||||
parser.add_argument("--master-port",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Master node port")
|
||||
args = parser.parse_args()
|
||||
|
||||
dp_size = args.dp_size
|
||||
tp_size = args.tp_size
|
||||
node_size = args.node_size
|
||||
node_rank = args.node_rank
|
||||
|
||||
if node_size == 1:
|
||||
dp_master_ip = "127.0.0.1"
|
||||
dp_master_port = get_open_port()
|
||||
else:
|
||||
dp_master_ip = args.master_addr
|
||||
dp_master_port = args.master_port
|
||||
|
||||
assert dp_size % node_size == 0, "dp_size should be divisible by node_size"
|
||||
dp_per_node = dp_size // node_size
|
||||
|
||||
from multiprocessing import Process
|
||||
dp_master_ip = "127.0.0.1"
|
||||
dp_master_port = get_open_port()
|
||||
|
||||
procs = []
|
||||
for i in range(DP_size):
|
||||
for local_dp_rank, global_dp_rank in enumerate(
|
||||
range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
|
||||
proc = Process(target=main,
|
||||
args=(DP_size, i, dp_master_ip, dp_master_port,
|
||||
GPUs_per_dp_rank))
|
||||
args=(args.model, dp_size, local_dp_rank,
|
||||
global_dp_rank, dp_master_ip, dp_master_port,
|
||||
tp_size))
|
||||
proc.start()
|
||||
procs.append(proc)
|
||||
exit_code = 0
|
||||
for proc in procs:
|
||||
proc.join()
|
||||
if proc.exitcode:
|
||||
proc.join(timeout=300)
|
||||
if proc.exitcode is None:
|
||||
print(f"Killing process {proc.pid} that "
|
||||
f"didn't stop within 5 minutes.")
|
||||
proc.kill()
|
||||
exit_code = 1
|
||||
elif proc.exitcode:
|
||||
exit_code = proc.exitcode
|
||||
|
||||
exit(exit_code)
|
||||
|
@ -14,10 +14,7 @@ answers = [
|
||||
]
|
||||
N = 1
|
||||
# Currently, top-p sampling is disabled. `top_p` should be 1.0.
|
||||
sampling_params = SamplingParams(temperature=0.7,
|
||||
top_p=1.0,
|
||||
n=N,
|
||||
max_tokens=16)
|
||||
sampling_params = SamplingParams(temperature=0, top_p=1.0, n=N, max_tokens=16)
|
||||
|
||||
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
|
||||
# In real workloads, `enforace_eager` should be `False`.
|
||||
|
@ -361,6 +361,7 @@ def run_llava_next_video(questions: list[str],
|
||||
engine_args = EngineArgs(
|
||||
model="llava-hf/LLaVA-NeXT-Video-7B-hf",
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
|
||||
|
@ -31,6 +31,7 @@ model = models.data[0].id
|
||||
|
||||
# Round 1
|
||||
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
|
||||
# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
|
||||
response = client.chat.completions.create(model=model, messages=messages)
|
||||
|
||||
reasoning_content = response.choices[0].message.reasoning_content
|
||||
|
@ -38,6 +38,7 @@ models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
|
||||
# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
|
||||
stream = client.chat.completions.create(model=model,
|
||||
messages=messages,
|
||||
stream=True)
|
||||
|
@ -4,14 +4,14 @@
|
||||
# Dependencies for CPUs
|
||||
torch==2.6.0+cpu; platform_machine == "x86_64"
|
||||
torch==2.6.0; platform_system == "Darwin"
|
||||
torch==2.5.1; platform_machine == "ppc64le" or platform_machine == "aarch64"
|
||||
torch==2.6.0; platform_machine == "ppc64le" or platform_machine == "aarch64"
|
||||
torch==2.7.0.dev20250304; platform_machine == "s390x"
|
||||
|
||||
# required for the image processor of minicpm-o-2_6, this must be updated alongside torch
|
||||
torchaudio; platform_machine != "ppc64le" and platform_machine != "s390x"
|
||||
torchaudio==2.5.1; platform_machine == "ppc64le"
|
||||
torchaudio==2.6.0; platform_machine == "ppc64le"
|
||||
|
||||
# required for the image processor of phi3v, this must be updated alongside torch
|
||||
torchvision; platform_machine != "ppc64le" and platform_machine != "s390x"
|
||||
torchvision==0.20.1; platform_machine == "ppc64le"
|
||||
torchvision==0.21.0; platform_machine == "ppc64le"
|
||||
datasets # for benchmark scripts
|
||||
|
@ -4,7 +4,7 @@
|
||||
numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
|
||||
|
||||
# Dependencies for NVIDIA GPUs
|
||||
ray[cgraph]>=2.43.0 # Ray Compiled Graph, required for pipeline parallelism in V1.
|
||||
ray[cgraph]>=2.43.0, !=2.44.* # Ray Compiled Graph, required for pipeline parallelism in V1.
|
||||
torch==2.6.0
|
||||
torchaudio==2.6.0
|
||||
# These must be updated alongside torch
|
||||
|
@ -17,7 +17,7 @@ vector_quantize_pytorch # required for minicpmo_26 test
|
||||
vocos # required for minicpmo_26 test
|
||||
peft
|
||||
pqdm
|
||||
ray[cgraph]>=2.43.0 # Ray Compiled Graph, required by pipeline parallelism tests
|
||||
ray[cgraph]>=2.43.0, !=2.44.* # Ray Compiled Graph, required by pipeline parallelism tests
|
||||
sentence-transformers # required for embedding tests
|
||||
soundfile # required for audio tests
|
||||
jiwer # required for audio tests
|
||||
|
@ -63,7 +63,8 @@ class LlamaConfig:
|
||||
factors.append((k, v))
|
||||
factors.sort()
|
||||
import hashlib
|
||||
return hashlib.md5(str(factors).encode()).hexdigest()
|
||||
return hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.mlp_size >= self.hidden_size
|
||||
|
@ -175,7 +175,7 @@ TEXT_GENERATION_MODELS = {
|
||||
"inceptionai/jais-13b-chat": PPTestSettings.fast(),
|
||||
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
|
||||
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
|
||||
# Tests TransformersModel
|
||||
# Tests TransformersForCausalLM
|
||||
"ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(),
|
||||
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
|
||||
"openbmb/MiniCPM3-4B": PPTestSettings.fast(),
|
||||
|
@ -0,0 +1,349 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.entrypoints.openai.reasoning_parsers.utils import (
|
||||
DeltaMessage, run_reasoning_extraction)
|
||||
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
|
||||
ReasoningParserManager)
|
||||
|
||||
parser_name = "granite"
|
||||
START_REASONING = "Here is my thought process:"
|
||||
START_RESPONSE = "Here is my response:"
|
||||
|
||||
SIMPLE_REASONING = {
|
||||
"output":
|
||||
f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", #noqa: E501
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
}
|
||||
COMPLETE_REASONING = {
|
||||
"output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": None,
|
||||
}
|
||||
NO_REASONING = {
|
||||
"output": "This is content",
|
||||
"reasoning_content": None,
|
||||
"content": "This is content",
|
||||
}
|
||||
MULTIPLE_LINES = {
|
||||
"output":
|
||||
f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat",
|
||||
"reasoning_content": "This\nThat",
|
||||
"content": "This is the rest\nThat",
|
||||
}
|
||||
REASONING_WITH_THINK = {
|
||||
"output":
|
||||
f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", #noqa: E501
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
}
|
||||
COMPLETE_REASONING_WITH_THINK = {
|
||||
"output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": None,
|
||||
}
|
||||
MULTIPLE_LINES_WITH_THINK = {
|
||||
"output":
|
||||
f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat",
|
||||
"reasoning_content": "This\nThat",
|
||||
"content": "This is the rest\nThat",
|
||||
}
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
False,
|
||||
SIMPLE_REASONING,
|
||||
id="simple_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLETE_REASONING,
|
||||
id="complete_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
NO_REASONING,
|
||||
id="no_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MULTIPLE_LINES,
|
||||
id="multiple_lines",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
REASONING_WITH_THINK,
|
||||
id="reasoning_with_think",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLETE_REASONING_WITH_THINK,
|
||||
id="complete_reasoning_with_think",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MULTIPLE_LINES_WITH_THINK,
|
||||
id="multiple_lines_with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
SIMPLE_REASONING,
|
||||
id="simple_reasoning_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLETE_REASONING,
|
||||
id="complete_reasoning_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
NO_REASONING,
|
||||
id="no_reasoning_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MULTIPLE_LINES,
|
||||
id="multiple_lines_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
REASONING_WITH_THINK,
|
||||
id="reasoning_with_think_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLETE_REASONING_WITH_THINK,
|
||||
id="complete_reasoning_with_think_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MULTIPLE_LINES_WITH_THINK,
|
||||
id="multiple_lines_with_think_streaming",
|
||||
),
|
||||
]
|
||||
|
||||
# Global tokenizer initialization to avoid repeated loading
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
|
||||
def test_reasoning(
|
||||
streaming: bool,
|
||||
param_dict: dict,
|
||||
):
|
||||
output = tokenizer.tokenize(param_dict["output"])
|
||||
# decode everything to tokens
|
||||
output_tokens: list[str] = [
|
||||
tokenizer.convert_tokens_to_string([token]) for token in output
|
||||
]
|
||||
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
|
||||
parser_name)(tokenizer)
|
||||
|
||||
reasoning, content = run_reasoning_extraction(parser,
|
||||
output_tokens,
|
||||
streaming=streaming)
|
||||
|
||||
assert reasoning == param_dict["reasoning_content"]
|
||||
assert content == param_dict["content"]
|
||||
|
||||
|
||||
# Additional tests for verifying the correctness of granite streaming; this
|
||||
# is complicated because granite uses multiple tokens to indicate when thinking
|
||||
# is starting / when it's starting its response, so skipping special tokens
|
||||
# is awkward.
|
||||
|
||||
### Handling the start of reasoning
|
||||
STREAMING_1 = {
|
||||
"previous_text": None,
|
||||
"current_text": "Here",
|
||||
"delta_text": "Here",
|
||||
"reasoning_content": None,
|
||||
"content": None,
|
||||
}
|
||||
# When we fail, we should give what was previously being silenced first
|
||||
STREAMING_2 = {
|
||||
"previous_text": "Here is my thought",
|
||||
"current_text": "Here is my thought failure",
|
||||
"delta_text": " failure",
|
||||
"reasoning_content": None,
|
||||
"content": "Here is my thought failure",
|
||||
}
|
||||
# But then after the first one, we should only add the delta text to content
|
||||
STREAMING_3 = {
|
||||
"previous_text": "Here wrong",
|
||||
"current_text": " words",
|
||||
"delta_text": " Here wrong words",
|
||||
"reasoning_content": None,
|
||||
"content": " words",
|
||||
}
|
||||
# But then after the first one, we should only add the delta text to content
|
||||
STREAMING_4 = {
|
||||
"previous_text": "Here is my thought",
|
||||
"current_text": "Here is my thought process:",
|
||||
"delta_text": " process:",
|
||||
"reasoning_content": None,
|
||||
"content": None,
|
||||
}
|
||||
# Reasoning started successfully; parse reasoning content
|
||||
STREAMING_5 = {
|
||||
"previous_text": "Here is my thought process:",
|
||||
"current_text": "Here is my thought process: foo",
|
||||
"delta_text": " foo",
|
||||
"reasoning_content": " foo",
|
||||
"content": None,
|
||||
}
|
||||
# Response special sequence has started, but not finished.
|
||||
STREAMING_6 = {
|
||||
"previous_text": "Here is my thought process: foo",
|
||||
"current_text": "Here is my thought process: foo Here is",
|
||||
"delta_text": " Here is",
|
||||
"reasoning_content": " ",
|
||||
"content": None,
|
||||
}
|
||||
# Response special sequence started, but was broken; the reasoning
|
||||
# content should be the content that was previously unused.
|
||||
STREAMING_7 = {
|
||||
"previous_text": "Here is my thought process: foo Here is",
|
||||
"current_text": "Here is my thought process: foo Here is Here",
|
||||
"delta_text": " Here",
|
||||
"reasoning_content": "Here is ",
|
||||
"content": None,
|
||||
}
|
||||
# Response special sequence is ongoing
|
||||
STREAMING_8 = {
|
||||
"previous_text": "Here is my thought process: foo Here is my response:",
|
||||
"current_text": "Here is my thought process: foo Here is my response: bar",
|
||||
"delta_text": " bar",
|
||||
"reasoning_content": None,
|
||||
"content": " bar",
|
||||
}
|
||||
# The delta text has everything; we should be able to correctly parse both
|
||||
STREAMING_9 = {
|
||||
"previous_text": None,
|
||||
"current_text": "Here is my thought process: foo Here is my response: bar",
|
||||
"delta_text": "Here is my thought process: foo Here is my response: bar",
|
||||
"reasoning_content": " foo ",
|
||||
"content": " bar",
|
||||
}
|
||||
## The Response is ongoing, and the delta mixes reasoning content / content
|
||||
STREAMING_10 = {
|
||||
"previous_text": "Here is my thought process: foo",
|
||||
"current_text":
|
||||
"Here is my thought process: foo bar Here is my response: baz",
|
||||
"delta_text": " bar Here is my response: baz",
|
||||
"reasoning_content": " bar ",
|
||||
"content": " baz",
|
||||
}
|
||||
# The delta text starts a new substring that might be a response special seq
|
||||
STREAMING_11 = {
|
||||
"previous_text":
|
||||
"Here is my thought process: This is a reasoning section ",
|
||||
"current_text":
|
||||
"Here is my thought process: This is a reasoning section Here",
|
||||
"delta_text": "Here",
|
||||
"reasoning_content": None,
|
||||
"content": None,
|
||||
}
|
||||
# The delta text is finishing the response special seq
|
||||
STREAMING_12 = {
|
||||
"previous_text": "Here is my thought process: foo Here is my response",
|
||||
"current_text": "Here is my thought process: foo Here is my response:",
|
||||
"delta_text": ":",
|
||||
"reasoning_content": None,
|
||||
"content": None,
|
||||
}
|
||||
STREAMING_13 = {
|
||||
"previous_text": "Here is my thought process: foo Here",
|
||||
"current_text": "Here is my thought process: foo Here was",
|
||||
"delta_text": " was",
|
||||
"reasoning_content": "Here was",
|
||||
"content": None,
|
||||
}
|
||||
|
||||
STREAMING_SUBCASES = [
|
||||
pytest.param(
|
||||
STREAMING_1,
|
||||
id="Starting reasoning special sequence",
|
||||
),
|
||||
pytest.param(
|
||||
STREAMING_2,
|
||||
id="Unexpected start reasoning sequence",
|
||||
),
|
||||
pytest.param(
|
||||
STREAMING_3,
|
||||
id="Continuing unexpected start reasoning sequence",
|
||||
),
|
||||
pytest.param(
|
||||
STREAMING_4,
|
||||
id="Only start reasoning sequence and nothing else",
|
||||
),
|
||||
pytest.param(
|
||||
STREAMING_5,
|
||||
id="Reasoning content has started",
|
||||
),
|
||||
pytest.param(
|
||||
STREAMING_6,
|
||||
id="Response special sequence has started",
|
||||
),
|
||||
pytest.param(
|
||||
STREAMING_7,
|
||||
id="Response special sequence reset",
|
||||
),
|
||||
pytest.param(
|
||||
STREAMING_8,
|
||||
id="Response text has started",
|
||||
),
|
||||
pytest.param(
|
||||
STREAMING_9,
|
||||
id="Delta contains everything",
|
||||
),
|
||||
pytest.param(
|
||||
STREAMING_10,
|
||||
id="Delta contains some reasoning and response",
|
||||
),
|
||||
pytest.param(
|
||||
STREAMING_11,
|
||||
id="Delta starts response sequence",
|
||||
),
|
||||
pytest.param(
|
||||
STREAMING_12,
|
||||
id="Delta finishes response sequence",
|
||||
),
|
||||
pytest.param(
|
||||
STREAMING_13,
|
||||
id="Delta breaks potential responise sequence",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("param_dict", STREAMING_SUBCASES)
|
||||
def test_streaming_subcases(param_dict):
|
||||
# Get all of the token IDs
|
||||
previous_token_ids = tokenizer.encode(
|
||||
param_dict["previous_text"]
|
||||
) if param_dict["previous_text"] is not None else []
|
||||
current_token_ids = tokenizer.encode(param_dict["current_text"])
|
||||
delta_token_ids = tokenizer.encode(param_dict["delta_text"])
|
||||
|
||||
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
|
||||
parser_name)(tokenizer)
|
||||
|
||||
response = parser.extract_reasoning_content_streaming(
|
||||
previous_text=param_dict["previous_text"],
|
||||
current_text=param_dict["current_text"],
|
||||
delta_text=param_dict["delta_text"],
|
||||
previous_token_ids=previous_token_ids,
|
||||
current_token_ids=current_token_ids,
|
||||
delta_token_ids=delta_token_ids,
|
||||
)
|
||||
# Streaming currently expects at least one of reasoning content / content,
|
||||
# so the response should return None in that case.
|
||||
if param_dict["reasoning_content"] is None and param_dict[
|
||||
"content"] is None:
|
||||
assert response is None
|
||||
else:
|
||||
assert isinstance(response, DeltaMessage)
|
||||
assert param_dict["reasoning_content"] == response.reasoning_content
|
||||
assert param_dict["content"] == response.content
|
@ -9,11 +9,11 @@ from transformers import __version__ as TRANSFORMERS_VERSION
|
||||
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import (_resolve_hf_chat_template,
|
||||
_try_extract_ast, load_chat_template,
|
||||
from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
|
||||
parse_chat_messages,
|
||||
parse_chat_messages_futures,
|
||||
resolve_chat_template_content_format)
|
||||
resolve_chat_template_content_format,
|
||||
resolve_hf_chat_template)
|
||||
from vllm.entrypoints.llm import apply_hf_chat_template
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.multimodal.utils import encode_image_base64
|
||||
@ -747,7 +747,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
|
||||
}] if use_tools else None
|
||||
|
||||
# Test detecting the tokenizer's chat_template
|
||||
chat_template = _resolve_hf_chat_template(
|
||||
chat_template = resolve_hf_chat_template(
|
||||
tokenizer,
|
||||
chat_template=None,
|
||||
tools=tools,
|
||||
@ -781,7 +781,7 @@ def test_resolve_content_format_hf_defined(model, expected_format):
|
||||
tokenizer = tokenizer_group.tokenizer
|
||||
|
||||
# Test detecting the tokenizer's chat_template
|
||||
chat_template = _resolve_hf_chat_template(
|
||||
chat_template = resolve_hf_chat_template(
|
||||
tokenizer,
|
||||
chat_template=None,
|
||||
tools=None,
|
||||
|
@ -749,3 +749,72 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
|
||||
|
||||
ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
|
||||
torch.testing.assert_close(dst, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.cpu_model
|
||||
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
|
||||
@torch.inference_mode()
|
||||
def test_concat_and_cache_mla_cpu(
|
||||
kv_lora_rank: int,
|
||||
qk_rope_head_dim: int,
|
||||
num_tokens: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
device = "cpu"
|
||||
kv_cache_dtype = "auto"
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
total_slots = num_blocks * block_size
|
||||
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping_lst,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
|
||||
k_pe = torch.randn(num_tokens,
|
||||
qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
entry_size = kv_lora_rank + qk_rope_head_dim
|
||||
|
||||
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
|
||||
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||
kv_cache_dtype, device)
|
||||
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
|
||||
|
||||
for i in range(num_tokens):
|
||||
slot = slot_mapping[i].item()
|
||||
block_idx = slot // block_size
|
||||
block_offset = slot % block_size
|
||||
ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
|
||||
ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
|
||||
ops.convert_fp8(ref_kv_cache,
|
||||
ref_temp,
|
||||
scale.item(),
|
||||
kv_dtype=kv_cache_dtype)
|
||||
else:
|
||||
ref_kv_cache = ref_temp
|
||||
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.concat_and_cache_mla,
|
||||
(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
|
||||
kv_cache_dtype, scale)
|
||||
torch.testing.assert_close(kv_cache, ref_kv_cache)
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
Run `pytest tests/kernels/test_cutlass.py`.
|
||||
"""
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -507,3 +508,136 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
|
||||
|
||||
def test_cutlass_support_opcheck():
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, ))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_experts", [8, 64])
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [False])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
|
||||
per_out_ch: bool, use_bias: bool):
|
||||
|
||||
# Device and dtype setup
|
||||
device = "cuda"
|
||||
out_dtype = torch.half
|
||||
|
||||
# Create separate A, B, C tensors for each group
|
||||
a_tensors = []
|
||||
b_tensors = []
|
||||
a_scales_tensors = []
|
||||
b_scales_tensors = []
|
||||
baseline_tensors = []
|
||||
|
||||
expert_offsets = torch.zeros((num_experts + 1),
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
|
||||
problem_sizes = torch.zeros((num_experts, 3),
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
|
||||
if not per_act_token:
|
||||
one_scale_a = torch.randn((1, 1), device=device, dtype=torch.float32)
|
||||
|
||||
alignment = 16 # 128 // 8
|
||||
# For variation, each group has dimensions
|
||||
n_g = alignment * random.randint(1, 64)
|
||||
k_g = alignment * random.randint(1, 64)
|
||||
for g in range(num_experts):
|
||||
m_g = alignment * random.randint(1, 64)
|
||||
|
||||
expert_offsets[g + 1] = expert_offsets[g] + m_g
|
||||
problem_sizes[g][0] = m_g
|
||||
problem_sizes[g][1] = n_g
|
||||
problem_sizes[g][2] = k_g
|
||||
|
||||
m_a_scales = m_g if per_act_token else 1
|
||||
n_b_scales = n_g if per_out_ch else 1
|
||||
|
||||
print("shape:", m_g, n_g, k_g)
|
||||
|
||||
# Create group-specific A and B (FP8) and output (FP16/FP32)
|
||||
a_g = to_fp8(torch.randn((m_g, k_g), device=device))
|
||||
b_g = to_fp8(torch.randn((n_g, k_g), device=device).t())
|
||||
a_tensors.append(a_g)
|
||||
b_tensors.append(b_g)
|
||||
|
||||
# Set up A/B scales
|
||||
scale_b = torch.randn((1, n_b_scales),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
b_scales_tensors.append(scale_b)
|
||||
|
||||
if per_act_token:
|
||||
scale_a = torch.randn((m_a_scales, 1),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
a_scales_tensors.append(scale_a)
|
||||
else:
|
||||
scale_a = one_scale_a
|
||||
|
||||
# Compute baseline result for this group
|
||||
baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype,
|
||||
None)
|
||||
baseline_tensors.append(baseline_g)
|
||||
|
||||
a_tensors_stacked = torch.empty((expert_offsets[num_experts], k_g),
|
||||
device=device,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
b_tensors_stacked = torch.empty((num_experts, n_g, k_g),
|
||||
device=device,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
|
||||
for g in range(num_experts):
|
||||
a_tensors_stacked[expert_offsets[g]:expert_offsets[g +
|
||||
1]] = a_tensors[g]
|
||||
b_tensors_stacked[g] = b_tensors[g].t()
|
||||
b_tensors_stacked = b_tensors_stacked.transpose(1, 2)
|
||||
|
||||
if per_act_token:
|
||||
a_scales_tensors_stacked = torch.empty(
|
||||
(expert_offsets[num_experts], 1),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
for g in range(num_experts):
|
||||
a_scales_tensors_stacked[
|
||||
expert_offsets[g]:expert_offsets[g + 1]] = a_scales_tensors[g]
|
||||
else:
|
||||
a_scales_tensors_stacked = one_scale_a
|
||||
|
||||
b_scales_tensors_stacked = torch.empty((num_experts, n_b_scales),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
for g in range(num_experts):
|
||||
b_scales_tensors_stacked[g] = b_scales_tensors[g]
|
||||
|
||||
out_tensors_stacked = torch.zeros((expert_offsets[num_experts], n_g),
|
||||
device=device,
|
||||
dtype=out_dtype)
|
||||
|
||||
ab_strides = torch.full((num_experts, ),
|
||||
a_tensors_stacked.stride(0),
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides = torch.full((num_experts, ),
|
||||
out_tensors_stacked.stride(0),
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
|
||||
ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked,
|
||||
b_tensors_stacked, a_scales_tensors_stacked,
|
||||
b_scales_tensors_stacked, expert_offsets[:-1],
|
||||
problem_sizes, ab_strides, ab_strides, c_strides)
|
||||
|
||||
# Validate each group's result against the baseline
|
||||
for g in range(num_experts):
|
||||
baseline = baseline_tensors[g]
|
||||
c = out_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]]
|
||||
print(baseline)
|
||||
print(c)
|
||||
print("*")
|
||||
torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-4)
|
||||
|
244
tests/kernels/test_cutlass_moe.py
Normal file
244
tests/kernels/test_cutlass_moe.py
Normal file
@ -0,0 +1,244 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
|
||||
fused_experts,
|
||||
fused_topk)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_EXPERTS = [40, 64]
|
||||
TOP_KS = [6, 8]
|
||||
|
||||
|
||||
def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
|
||||
w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
return cutlass_moe_fp8(a,
|
||||
w1_q,
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
ab_strides1,
|
||||
c_strides1,
|
||||
ab_strides2,
|
||||
c_strides2,
|
||||
a1_scale=a_scale)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [2, 64, 224])
|
||||
@pytest.mark.parametrize("n", [1024, 3072])
|
||||
@pytest.mark.parametrize("k", [1024, 1536])
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
def test_cutlass_moe_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
|
||||
dtype = torch.half
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
|
||||
# Get the right scale for tests.
|
||||
_, a_scale1 = ops.scaled_fp8_quant(
|
||||
a, use_per_token_if_dynamic=per_act_token)
|
||||
a_q, _ = ops.scaled_fp8_quant(a,
|
||||
a_scale1,
|
||||
use_per_token_if_dynamic=per_act_token)
|
||||
|
||||
a_d = a_q.float().mul(a_scale1).to(dtype)
|
||||
|
||||
n_b_scales = 2 * n if per_out_ch else 1
|
||||
k_b_scales = k if per_out_ch else 1
|
||||
|
||||
w1_q = torch.empty((e, 2 * n, k),
|
||||
device="cuda",
|
||||
dtype=torch.float8_e4m3fn)
|
||||
w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
|
||||
w1_scale = torch.empty((e, n_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
for expert in range(e):
|
||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
|
||||
w1[expert], use_per_token_if_dynamic=per_out_ch)
|
||||
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
w2[expert], use_per_token_if_dynamic=per_out_ch)
|
||||
w1_q = w1_q.transpose(1, 2)
|
||||
w2_q = w2_q.transpose(1, 2)
|
||||
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
w1_d = torch.empty_like(w1)
|
||||
w2_d = torch.empty_like(w2)
|
||||
for expert in range(e):
|
||||
w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half()
|
||||
w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half()
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
|
||||
|
||||
triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids)
|
||||
|
||||
cutlass_output = cutlass_moe_fp8(a,
|
||||
w1_q,
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
ab_strides1,
|
||||
c_strides1,
|
||||
ab_strides2,
|
||||
c_strides2,
|
||||
a1_scale=a_scale1)
|
||||
|
||||
print(triton_output)
|
||||
print(cutlass_output)
|
||||
print("*")
|
||||
|
||||
torch.testing.assert_close(triton_output,
|
||||
cutlass_output,
|
||||
atol=5e-2,
|
||||
rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [2, 64, 224])
|
||||
@pytest.mark.parametrize("n", [1024, 3072])
|
||||
@pytest.mark.parametrize("k", [1024, 1536])
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
def test_cutlass_moe_cuda_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
|
||||
dtype = torch.half
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
|
||||
# Get the right scale for tests.
|
||||
_, a_scale1 = ops.scaled_fp8_quant(
|
||||
a, use_per_token_if_dynamic=per_act_token)
|
||||
a_q, _ = ops.scaled_fp8_quant(a,
|
||||
a_scale1,
|
||||
use_per_token_if_dynamic=per_act_token)
|
||||
|
||||
a_d = a_q.float().mul(a_scale1).to(dtype)
|
||||
|
||||
n_b_scales = 2 * n if per_out_ch else 1
|
||||
k_b_scales = k if per_out_ch else 1
|
||||
|
||||
w1_q = torch.empty((e, 2 * n, k),
|
||||
device="cuda",
|
||||
dtype=torch.float8_e4m3fn)
|
||||
w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
|
||||
w1_scale = torch.empty((e, n_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
for expert in range(e):
|
||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
|
||||
w1[expert], use_per_token_if_dynamic=per_out_ch)
|
||||
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
w2[expert], use_per_token_if_dynamic=per_out_ch)
|
||||
w1_q = w1_q.transpose(1, 2)
|
||||
w2_q = w2_q.transpose(1, 2)
|
||||
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
w1_d = torch.empty_like(w1)
|
||||
w2_d = torch.empty_like(w2)
|
||||
for expert in range(e):
|
||||
w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half()
|
||||
w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half()
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
|
||||
|
||||
triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids)
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
cutlass_output = run(a, a_scale1, w1_q, w2_q, w1_scale, w2_scale,
|
||||
topk_weights, topk_ids, ab_strides1,
|
||||
c_strides1, ab_strides2, c_strides2)
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
print(triton_output)
|
||||
print(cutlass_output)
|
||||
print("*")
|
||||
|
||||
torch.testing.assert_close(triton_output,
|
||||
cutlass_output,
|
||||
atol=9e-2,
|
||||
rtol=1e-2)
|
94
tests/kernels/test_mla_decode_cpu.py
Normal file
94
tests/kernels/test_mla_decode_cpu.py
Normal file
@ -0,0 +1,94 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def cdiv(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
|
||||
def ref_mla(
|
||||
out: Tensor, # (bs, num_heads, v_head_dim)
|
||||
query: Tensor, # (bs, num_heads, head_dim)
|
||||
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
|
||||
scale: float,
|
||||
block_tables: Tensor, # (bs, max_num_blocks)
|
||||
seq_lens: Tensor, # (bs,)
|
||||
):
|
||||
bs, num_heads, v_head_dim = out.shape
|
||||
head_dim = query.shape[2]
|
||||
|
||||
for i in range(bs):
|
||||
# gather and flatten KV-cache
|
||||
kv = kv_cache[
|
||||
block_tables[i]] # (max_num_blocks, block_size, head_dim)
|
||||
kv = kv.view(1, -1,
|
||||
head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim)
|
||||
v = kv[:, :, :v_head_dim]
|
||||
|
||||
q = query[i].view(num_heads, 1, head_dim)
|
||||
o = F.scaled_dot_product_attention(q,
|
||||
kv,
|
||||
v,
|
||||
scale=scale,
|
||||
enable_gqa=True)
|
||||
out[i] = o.view(num_heads, v_head_dim)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bs", [4])
|
||||
@pytest.mark.parametrize("mean_seq_len", [256])
|
||||
@pytest.mark.parametrize("h_q", [16])
|
||||
@pytest.mark.parametrize("d", [576])
|
||||
@pytest.mark.parametrize("dv", [512])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float, torch.half, torch.bfloat16])
|
||||
@pytest.mark.parametrize("varlen", [False, True])
|
||||
@pytest.mark.cpu_model
|
||||
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
|
||||
def test_mla_decode_cpu(
|
||||
bs: int,
|
||||
mean_seq_len: int,
|
||||
h_q: int,
|
||||
d: int,
|
||||
dv: int,
|
||||
block_size: int,
|
||||
dtype: torch.dtype,
|
||||
varlen: bool,
|
||||
):
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(0)
|
||||
|
||||
scale = d**(-0.5)
|
||||
if varlen:
|
||||
seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
|
||||
seq_lens = seq_lens.clip(2).to(torch.int32)
|
||||
else:
|
||||
seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32)
|
||||
max_seq_len = seq_lens.max().item()
|
||||
seqlen_pad = cdiv(max_seq_len, 256) * 256 # is this necessary?
|
||||
|
||||
q = torch.randn(bs, h_q, d)
|
||||
block_table = torch.arange(bs * seqlen_pad // block_size,
|
||||
dtype=torch.int32)
|
||||
block_table = block_table.view(bs, seqlen_pad // block_size)
|
||||
|
||||
kv_cache = torch.randn(block_table.numel(), block_size, d)
|
||||
for i, seq_len in enumerate(seq_lens.tolist()):
|
||||
kv_cache.view(bs, seqlen_pad, d)[i, seq_len:] = float("nan")
|
||||
|
||||
out_mla = q.new_zeros(bs, h_q, dv)
|
||||
ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table,
|
||||
seq_lens)
|
||||
|
||||
out_ref = q.new_zeros(bs, h_q, dv)
|
||||
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
|
||||
|
||||
assert not out_mla.isnan().any(), "Likely read out of bounds"
|
||||
torch.testing.assert_close(out_mla, out_ref)
|
@ -3,7 +3,6 @@
|
||||
|
||||
Run `pytest tests/kernels/test_moe.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
@ -216,11 +215,17 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("padding", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
|
||||
@torch.inference_mode()
|
||||
def test_mixtral_moe(dtype: torch.dtype, padding: bool):
|
||||
def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
|
||||
monkeypatch):
|
||||
"""Make sure our Mixtral MoE implementation agrees with the one from
|
||||
huggingface."""
|
||||
|
||||
if use_rocm_aiter:
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
# Instantiate our and huggingface's MoE blocks
|
||||
config = MixtralConfig()
|
||||
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
|
||||
@ -268,10 +273,18 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool):
|
||||
torch.bfloat16: 1e-2,
|
||||
}
|
||||
|
||||
torch.testing.assert_close(hf_states.flatten(0, 1),
|
||||
vllm_states,
|
||||
rtol=mixtral_moe_tol[dtype],
|
||||
atol=mixtral_moe_tol[dtype])
|
||||
if use_rocm_aiter:
|
||||
# The values of rtol and atol are set based on the tests in ROCM AITER package. # noqa: E501
|
||||
# https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 # noqa: E501
|
||||
torch.testing.assert_close(hf_states.flatten(0, 1),
|
||||
vllm_states,
|
||||
rtol=0.01,
|
||||
atol=100)
|
||||
else:
|
||||
torch.testing.assert_close(hf_states.flatten(0, 1),
|
||||
vllm_states,
|
||||
rtol=mixtral_moe_tol[dtype],
|
||||
atol=mixtral_moe_tol[dtype])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
||||
|
@ -241,39 +241,6 @@ def long_context_lora_files_16k_1():
|
||||
return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def long_context_lora_files_16k_2():
|
||||
return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_2")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def long_context_lora_files_32k():
|
||||
return snapshot_download(repo_id="SangBinCho/long_context_32k_testing")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def long_context_infos(long_context_lora_files_16k_1,
|
||||
long_context_lora_files_16k_2,
|
||||
long_context_lora_files_32k):
|
||||
cleanup_dist_env_and_memory(shutdown_ray=True)
|
||||
infos: dict[int, ContextInfo] = {}
|
||||
for lora_checkpoint_info in LONG_LORA_INFOS:
|
||||
lora_id = lora_checkpoint_info["lora_id"]
|
||||
if lora_id == 1:
|
||||
lora = long_context_lora_files_16k_1
|
||||
elif lora_id == 2:
|
||||
lora = long_context_lora_files_16k_2
|
||||
elif lora_id == 3:
|
||||
lora = long_context_lora_files_32k
|
||||
else:
|
||||
raise AssertionError("Unknown lora id")
|
||||
infos[lora_id] = {
|
||||
"context_length": lora_checkpoint_info["context_length"],
|
||||
"lora": lora,
|
||||
}
|
||||
return infos
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llama_2_7b_engine_extra_embeddings():
|
||||
cleanup_dist_env_and_memory(shutdown_ray=True)
|
||||
|
File diff suppressed because one or more lines are too long
@ -1,301 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import ast
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import vllm
|
||||
from vllm import SamplingParams
|
||||
from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLoRA
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
LinearScalingRotaryEmbedding)
|
||||
|
||||
from .data.long_context_test_data import prompts_and_responses
|
||||
|
||||
context_len_to_scaling_factor = {
|
||||
"16k": 4,
|
||||
"32k": 8,
|
||||
}
|
||||
|
||||
# We use the same sampling params for all requests
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
|
||||
def _create_lora_request(lora_id, long_context_infos):
|
||||
context_len = long_context_infos[lora_id]["context_length"]
|
||||
scaling_factor = context_len_to_scaling_factor[context_len]
|
||||
return LoRARequest(
|
||||
# There are 2 LoRAs for 16K, we need to add lora_id to indicate
|
||||
# they are different LoRAs.
|
||||
context_len + str(lora_id),
|
||||
lora_id,
|
||||
long_context_infos[lora_id]["lora"],
|
||||
None,
|
||||
4096 * scaling_factor,
|
||||
)
|
||||
|
||||
|
||||
def evaluate_json_response(model_response, golden_response):
|
||||
"""Evaluates the model response against the golden response.
|
||||
|
||||
Returns a score between 0 and 1, where 1 is a perfect match and 0 is no
|
||||
match. The score quantifies how well the model is able to extract the
|
||||
golden JSON from the long context.
|
||||
"""
|
||||
try:
|
||||
model_response = ast.literal_eval(model_response)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Model response is not a valid JSON. Expected {golden_response}, "
|
||||
f"got {model_response}") from e
|
||||
|
||||
# Normally, we would flatten the dictionary and compare the values, but in
|
||||
# this case, we know that the dictionary is only 2 levels deep
|
||||
positive_values = 0
|
||||
total_values = 0
|
||||
# We look at all the attributes of the person that we are extracting a
|
||||
# biography of and copmare them to the golden response
|
||||
for person_attribute, person_attribute_value in golden_response.items():
|
||||
if person_attribute in model_response:
|
||||
if isinstance(person_attribute_value, dict):
|
||||
for (sub_attribute,
|
||||
sub_attribute_value) in person_attribute_value.items():
|
||||
total_values += 1
|
||||
if sub_attribute in model_response[
|
||||
person_attribute] and model_response[
|
||||
person_attribute][
|
||||
sub_attribute] == sub_attribute_value:
|
||||
positive_values += 1
|
||||
else:
|
||||
total_values += 1
|
||||
if model_response[person_attribute] == person_attribute_value:
|
||||
positive_values += 1
|
||||
else:
|
||||
# We count a missing sub-dict as a single missed value.
|
||||
total_values += 1
|
||||
|
||||
# Return a score between 0 and 1
|
||||
return positive_values / total_values
|
||||
|
||||
|
||||
def generate(
|
||||
llm: vllm.LLM,
|
||||
inputs: tuple[str, SamplingParams, Optional[LoRARequest]],
|
||||
):
|
||||
prompts, sampling_param, lora_request = inputs
|
||||
outputs = llm.generate(prompts, sampling_param, lora_request=lora_request)
|
||||
return outputs[0].outputs[0].text.strip()
|
||||
|
||||
|
||||
def batched_generate(
|
||||
llm: vllm.LLM,
|
||||
inputs: list[tuple[str, SamplingParams, Optional[LoRARequest]]],
|
||||
):
|
||||
for input in inputs:
|
||||
prompt, sampling_param, lora_req = input
|
||||
# Add requests to the engine and run the engine
|
||||
llm._validate_and_add_requests(prompt,
|
||||
sampling_param,
|
||||
lora_request=lora_req,
|
||||
prompt_adapter_request=None)
|
||||
|
||||
outputs = llm._run_engine(use_tqdm=True)
|
||||
return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def lora_llm(long_context_infos):
|
||||
scaling_factors = [
|
||||
context_len_to_scaling_factor[info["context_length"]]
|
||||
for info in long_context_infos.values()
|
||||
]
|
||||
|
||||
llm = vllm.LLM(
|
||||
"meta-llama/Llama-2-13b-chat-hf",
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_loras=2,
|
||||
long_lora_scaling_factors=tuple(scaling_factors),
|
||||
max_num_batched_tokens=4096 * 8,
|
||||
tensor_parallel_size=4,
|
||||
# FIXME enable async output processor
|
||||
disable_async_output_proc=True,
|
||||
distributed_executor_backend="mp",
|
||||
enable_chunked_prefill=True)
|
||||
yield llm
|
||||
del llm
|
||||
|
||||
|
||||
def test_rotary_emb_replaced(dist_init):
|
||||
"""Verify rotary emb in all the layers are replaced"""
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
engine_args = EngineArgs("meta-llama/Llama-2-7b-hf",
|
||||
long_lora_scaling_factors=(4.0, ),
|
||||
enable_lora=True)
|
||||
engine_config = engine_args.create_engine_config()
|
||||
model_runner = ModelRunner(
|
||||
vllm_config=engine_config,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
model_runner.load_model()
|
||||
rotary_emb_count = 0
|
||||
for module_name, module in model_runner.model.named_modules(
|
||||
remove_duplicate=False):
|
||||
if "rotary_emb" in module_name:
|
||||
if "base_layer" not in module_name:
|
||||
rotary_emb_count += 1
|
||||
assert isinstance(module, LinearScalingRotaryEmbeddingWithLoRA)
|
||||
else:
|
||||
assert isinstance(module, LinearScalingRotaryEmbedding)
|
||||
# Llama 2 has 32 layers.
|
||||
assert rotary_emb_count == 32
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_batched_rope_kernel(lora_llm, long_context_infos):
|
||||
"""We test the batched kernel by comparing the results of batched an
|
||||
non-batched generation.
|
||||
"""
|
||||
# Create non batched results first to compare against batched results
|
||||
non_batched_results: list[str] = []
|
||||
|
||||
for lora_id, info in long_context_infos.items():
|
||||
context_len = info["context_length"]
|
||||
lora_prompt = (prompts_and_responses[context_len][0]["prompt"],
|
||||
sampling_params,
|
||||
_create_lora_request(lora_id, long_context_infos))
|
||||
lora_output = generate(lora_llm, lora_prompt)
|
||||
non_batched_results.append(lora_output)
|
||||
|
||||
# Create batched results
|
||||
# Each element of the batch must be
|
||||
# (prompt, prompt_sampling_params, prompt_lora_request)
|
||||
batched_prompts: list[tuple[str, SamplingParams,
|
||||
Optional[LoRARequest]]] = []
|
||||
for lora_id, info in long_context_infos.items():
|
||||
context_len = info["context_length"]
|
||||
batched_prompts.extend([
|
||||
(prompts_and_responses[context_len][0]["prompt"], sampling_params,
|
||||
_create_lora_request(lora_id, long_context_infos))
|
||||
])
|
||||
batched_results = batched_generate(lora_llm, batched_prompts)
|
||||
|
||||
# Results should be the same
|
||||
for non_batched, batched in zip(non_batched_results, batched_results):
|
||||
assert non_batched == batched, (
|
||||
"Non batched and batched results should be the "
|
||||
f"same:\n{batched}\n{non_batched}")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_self_consistency(lora_llm, long_context_infos):
|
||||
"""We test consistency of the batched kernel by permuting batched
|
||||
inputs and comparing the results to the non-permuted batched results.
|
||||
"""
|
||||
num_loras = len(long_context_infos)
|
||||
|
||||
# Create results in order of long_context_infos
|
||||
batched_prompts: list[tuple[str, SamplingParams,
|
||||
Optional[LoRARequest]]] = []
|
||||
for lora_id, info in long_context_infos.items():
|
||||
context_len = info["context_length"]
|
||||
batched_prompts.extend([
|
||||
(prompts_and_responses[context_len][0]["prompt"], sampling_params,
|
||||
_create_lora_request(lora_id, long_context_infos))
|
||||
])
|
||||
|
||||
batched_results = batched_generate(lora_llm, batched_prompts)
|
||||
|
||||
permutation = np.random.default_rng(seed=42).permutation(num_loras)
|
||||
|
||||
# Create results in random order of permutation
|
||||
batched_prompts = []
|
||||
for i in permutation:
|
||||
lora_id, info = list(long_context_infos.items())[i]
|
||||
context_len = info["context_length"]
|
||||
batched_prompts.extend([
|
||||
(prompts_and_responses[context_len][0]["prompt"], sampling_params,
|
||||
_create_lora_request(lora_id, long_context_infos))
|
||||
])
|
||||
|
||||
permutated_batched_results = batched_generate(lora_llm, batched_prompts)
|
||||
|
||||
# Results should be the same
|
||||
for i in range(num_loras):
|
||||
assert batched_results[i] == permutated_batched_results[
|
||||
permutation[i]], (
|
||||
f"Results should be the same:\n{batched_results[i]}"
|
||||
f"\n{permutated_batched_results[permutation[i]]}")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_quality(lora_llm, long_context_infos):
|
||||
"""We test the quality of the answers given by the LoRA model by
|
||||
comparing the generated text to the merged model's outputs.
|
||||
|
||||
This is effectively a mini-benchmark over four prompts.
|
||||
If this test fails, this indicates that the quality of the LoRA model
|
||||
is suboptimal compared to the merged model. For example, if the model
|
||||
does not output valid dictionaries, this test will fail.
|
||||
|
||||
If needed for testing, the merged versions of the models are available
|
||||
as part of the `conftest`.
|
||||
|
||||
The test is expected to run for about 1 minute on a p4de.24xlarge
|
||||
instance.
|
||||
"""
|
||||
scores: list[float] = []
|
||||
for lora_id, info in long_context_infos.items():
|
||||
context_len = info["context_length"]
|
||||
for prompt_and_response in prompts_and_responses[context_len]:
|
||||
lora_prompt = (prompt_and_response["prompt"], sampling_params,
|
||||
_create_lora_request(lora_id, long_context_infos))
|
||||
response = generate(lora_llm, lora_prompt)
|
||||
golden_answer = prompt_and_response["golden_answer"]
|
||||
score = evaluate_json_response(response, golden_answer)
|
||||
scores.append(score)
|
||||
assert score > 0.3, ("Quality of the answer is not good enough. "
|
||||
f"Expected {golden_answer}, got {response}")
|
||||
assert np.mean(scores) > 0.5
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_max_len(lora_llm, long_context_infos):
|
||||
"""Test that we raise an ValueError when the input of a given LoRA
|
||||
model exceeds the maximum length."""
|
||||
# Since each LoRA model has a different maximum length, we need to
|
||||
# test each one separately
|
||||
for lora_id, info in long_context_infos.items():
|
||||
context_len = info["context_length"]
|
||||
lora_request = _create_lora_request(lora_id, long_context_infos)
|
||||
# Good prompt should be fine
|
||||
good_prompt = prompts_and_responses[context_len][0]["prompt"]
|
||||
generate(lora_llm, (good_prompt, sampling_params, lora_request))
|
||||
# Bad prompt should raise an error
|
||||
bad_prompt = good_prompt * 2
|
||||
with pytest.raises(ValueError):
|
||||
generate(lora_llm, (bad_prompt, sampling_params, lora_request))
|
||||
|
||||
# Also test batched
|
||||
batched_prompts: list[tuple[str, SamplingParams,
|
||||
Optional[LoRARequest]]] = []
|
||||
for lora_id_with_bad_inputs in long_context_infos:
|
||||
for lora_id, info in long_context_infos.items():
|
||||
context_len = info["context_length"]
|
||||
batched_prompts.extend([
|
||||
(prompts_and_responses[context_len][0]["prompt"] *
|
||||
(2 if lora_id == lora_id_with_bad_inputs else 1),
|
||||
sampling_params,
|
||||
_create_lora_request(lora_id, long_context_infos))
|
||||
])
|
||||
# Turn good prompt into bad prompt inside of batched prompts
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
batched_generate(lora_llm, batched_prompts)
|
@ -7,6 +7,10 @@ from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.activation import (GeluAndMul,
|
||||
ReLUSquaredActivation,
|
||||
SiluAndMul)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
dispatch_fused_experts_func, dispatch_topk_func,
|
||||
torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
|
||||
vllm_topk_softmax)
|
||||
from vllm.model_executor.layers.layernorm import (
|
||||
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
|
||||
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
|
||||
@ -92,6 +96,38 @@ def test_enabled_ops_invalid(env: str):
|
||||
RMSNorm(1024).enabled()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
||||
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
||||
topk_func = dispatch_topk_func()
|
||||
|
||||
if current_platform.is_rocm() and int(use_rocm_aiter):
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_topk_softmax)
|
||||
|
||||
assert topk_func == rocm_aiter_topk_softmax
|
||||
else:
|
||||
assert topk_func == vllm_topk_softmax
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
||||
@pytest.mark.parametrize("inplace", [True, False])
|
||||
def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
|
||||
monkeypatch):
|
||||
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
||||
fused_experts_func = dispatch_fused_experts_func(inplace)
|
||||
if current_platform.is_rocm() and int(use_rocm_aiter):
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_fused_experts)
|
||||
|
||||
assert fused_experts_func == rocm_aiter_fused_experts
|
||||
elif inplace:
|
||||
assert fused_experts_func == torch_vllm_inplace_fused_experts
|
||||
else:
|
||||
assert fused_experts_func == torch_vllm_outplace_fused_experts
|
||||
|
||||
|
||||
@pytest.mark.parametrize("add_residual", [True, False])
|
||||
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
||||
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
|
||||
|
@ -174,15 +174,8 @@ SAMPLE_JSON_SCHEMA = {
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
def test_models(hf_runner, vllm_runner, example_prompts, model: str,
|
||||
dtype: str, max_tokens: int, num_logprobs: int) -> None:
|
||||
# TODO(sang): Sliding window should be tested separately.
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
@ -206,14 +199,8 @@ def test_models(
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_mistral_format(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str,
|
||||
max_tokens: int, num_logprobs: int) -> None:
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
@ -244,11 +231,8 @@ def test_mistral_format(
|
||||
|
||||
@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
def test_mistral_symbolic_languages(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
def test_mistral_symbolic_languages(vllm_runner, model: str,
|
||||
dtype: str) -> None:
|
||||
with vllm_runner(model,
|
||||
dtype=dtype,
|
||||
max_model_len=8192,
|
||||
@ -266,11 +250,7 @@ def test_mistral_symbolic_languages(
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("model",
|
||||
MISTRAL_FORMAT_MODELS) # v1 can't do func calling
|
||||
def test_mistral_function_calling(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
|
||||
with vllm_runner(model,
|
||||
dtype=dtype,
|
||||
tokenizer_mode="mistral",
|
||||
@ -301,11 +281,8 @@ def test_mistral_function_calling(
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("guided_backend",
|
||||
["outlines", "lm-format-enforcer", "xgrammar"])
|
||||
def test_mistral_guided_decoding(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
guided_backend: str,
|
||||
) -> None:
|
||||
def test_mistral_guided_decoding(vllm_runner, model: str,
|
||||
guided_backend: str) -> None:
|
||||
with vllm_runner(model, dtype='bfloat16',
|
||||
tokenizer_mode="mistral") as vllm_model:
|
||||
|
||||
|
@ -163,24 +163,24 @@ VLM_TEST_SETTINGS = {
|
||||
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
|
||||
),
|
||||
#### Extended model tests
|
||||
# "aria": VLMTestInfo(
|
||||
# models=["rhymes-ai/Aria"],
|
||||
# test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
# prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
|
||||
# img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
|
||||
# max_model_len=4096,
|
||||
# max_num_seqs=2,
|
||||
# auto_cls=AutoModelForImageTextToText,
|
||||
# single_image_prompts=IMAGE_ASSETS.prompts({
|
||||
# "stop_sign": "<vlm_image>Please describe the image shortly.",
|
||||
# "cherry_blossom": "<vlm_image>Please infer the season with reason.", # noqa: E501
|
||||
# }),
|
||||
# multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501
|
||||
# stop_str=["<|im_end|>"],
|
||||
# image_size_factors=[(0.10, 0.15)],
|
||||
# max_tokens=64,
|
||||
# marks=[large_gpu_mark(min_gb=64)],
|
||||
# ),
|
||||
"aria": VLMTestInfo(
|
||||
models=["rhymes-ai/Aria"],
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
|
||||
img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForImageTextToText,
|
||||
single_image_prompts=IMAGE_ASSETS.prompts({
|
||||
"stop_sign": "<vlm_image>Please describe the image shortly.",
|
||||
"cherry_blossom": "<vlm_image>Please infer the season with reason.", # noqa: E501
|
||||
}),
|
||||
multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501
|
||||
stop_str=["<|im_end|>"],
|
||||
image_size_factors=[(0.10, 0.15)],
|
||||
max_tokens=64,
|
||||
marks=[large_gpu_mark(min_gb=64)],
|
||||
),
|
||||
"blip2": VLMTestInfo(
|
||||
models=["Salesforce/blip2-opt-2.7b"],
|
||||
test_type=VLMTestType.IMAGE,
|
||||
@ -352,6 +352,7 @@ VLM_TEST_SETTINGS = {
|
||||
prompt_formatter=lambda vid_prompt: f"USER: {vid_prompt} ASSISTANT:",
|
||||
num_video_frames=16,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForVision2Seq,
|
||||
vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output,
|
||||
),
|
||||
@ -384,7 +385,7 @@ VLM_TEST_SETTINGS = {
|
||||
),
|
||||
"minicpmo_26": VLMTestInfo(
|
||||
models=["openbmb/MiniCPM-o-2_6"],
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
test_type=(VLMTestType.IMAGE),
|
||||
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
|
||||
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
|
||||
max_model_len=4096,
|
||||
@ -393,9 +394,21 @@ VLM_TEST_SETTINGS = {
|
||||
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
|
||||
patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
|
||||
),
|
||||
"minicpmo_26_multi_image": VLMTestInfo(
|
||||
models=["openbmb/MiniCPM-o-2_6"],
|
||||
test_type=(VLMTestType.MULTI_IMAGE),
|
||||
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
|
||||
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
|
||||
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
|
||||
patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
|
||||
marks=[large_gpu_mark(min_gb=32)],
|
||||
),
|
||||
"minicpmv_26": VLMTestInfo(
|
||||
models=["openbmb/MiniCPM-V-2_6"],
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
test_type=(VLMTestType.IMAGE),
|
||||
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
|
||||
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
|
||||
max_model_len=4096,
|
||||
@ -404,9 +417,21 @@ VLM_TEST_SETTINGS = {
|
||||
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
|
||||
patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
|
||||
),
|
||||
"minicpmv_26_multi_image": VLMTestInfo(
|
||||
models=["openbmb/MiniCPM-V-2_6"],
|
||||
test_type=(VLMTestType.MULTI_IMAGE),
|
||||
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
|
||||
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
|
||||
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
|
||||
patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
|
||||
marks=[large_gpu_mark(min_gb=32)],
|
||||
),
|
||||
"molmo": VLMTestInfo(
|
||||
models=["allenai/Molmo-7B-D-0924"],
|
||||
test_type=(VLMTestType.IMAGE),
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
prompt_formatter=identity,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
|
@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import copy
|
||||
from functools import partial
|
||||
from typing import Optional, Union
|
||||
|
||||
@ -29,7 +28,7 @@ def _test_processing_correctness(
|
||||
hit_rate: float,
|
||||
num_batches: int,
|
||||
simplify_rate: float,
|
||||
ignore_mm_keys: Optional[list[str]] = None,
|
||||
ignore_mm_keys: Optional[set[str]] = None,
|
||||
):
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
@ -145,7 +144,7 @@ def _test_processing_correctness_hf(
|
||||
baseline_processor: BaseMultiModalProcessor,
|
||||
cached_processor: BaseMultiModalProcessor,
|
||||
batch_idx: int,
|
||||
ignore_mm_keys: Optional[list[str]] = None,
|
||||
ignore_mm_keys: Optional[set[str]] = None,
|
||||
):
|
||||
if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
|
||||
# For some multimodal models, tokenizer will always add bos_token
|
||||
@ -167,11 +166,12 @@ def _test_processing_correctness_hf(
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
|
||||
assert _inputs_equal(
|
||||
_assert_inputs_equal(
|
||||
baseline_result,
|
||||
cached_result,
|
||||
ignore_mm_keys,
|
||||
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
|
||||
)
|
||||
|
||||
baseline_tokenized_result = baseline_processor.apply(
|
||||
token_prompt,
|
||||
@ -179,11 +179,12 @@ def _test_processing_correctness_hf(
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
|
||||
assert _inputs_equal(
|
||||
_assert_inputs_equal(
|
||||
baseline_result,
|
||||
baseline_tokenized_result,
|
||||
ignore_mm_keys,
|
||||
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
|
||||
)
|
||||
|
||||
cached_tokenized_result = cached_processor.apply(
|
||||
token_prompt,
|
||||
@ -191,11 +192,12 @@ def _test_processing_correctness_hf(
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
|
||||
assert _inputs_equal(
|
||||
_assert_inputs_equal(
|
||||
cached_result,
|
||||
cached_tokenized_result,
|
||||
ignore_mm_keys,
|
||||
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
|
||||
)
|
||||
|
||||
|
||||
def _test_processing_correctness_mistral(
|
||||
@ -206,7 +208,7 @@ def _test_processing_correctness_mistral(
|
||||
baseline_processor: BaseMultiModalProcessor,
|
||||
cached_processor: BaseMultiModalProcessor,
|
||||
batch_idx: int,
|
||||
ignore_mm_keys: Optional[list[str]] = None,
|
||||
ignore_mm_keys: Optional[set[str]] = None,
|
||||
):
|
||||
images = mm_data.get("image", [])
|
||||
if not isinstance(images, list):
|
||||
@ -233,11 +235,12 @@ def _test_processing_correctness_mistral(
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
|
||||
assert _inputs_equal(
|
||||
_assert_inputs_equal(
|
||||
baseline_tokenized_result,
|
||||
cached_tokenized_result,
|
||||
ignore_mm_keys,
|
||||
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
|
||||
)
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@ -261,6 +264,7 @@ def _test_processing_correctness_mistral(
|
||||
"TIGER-Lab/Mantis-8B-siglip-llama3",
|
||||
"mistralai/Pixtral-12B-2409",
|
||||
"mistral-community/pixtral-12b",
|
||||
"openbmb/MiniCPM-Llama3-V-2_5",
|
||||
"openbmb/MiniCPM-o-2_6",
|
||||
"openbmb/MiniCPM-V-2_6",
|
||||
"allenai/Molmo-7B-D-0924",
|
||||
@ -290,7 +294,7 @@ def test_processing_correctness(
|
||||
# In Ultravox, the audio_features can be different depending on padding
|
||||
# The slight difference should not be a problem though, since
|
||||
# attention_mask lets us ignore the difference.
|
||||
ignore_mm_keys = ['audio_features']
|
||||
ignore_mm_keys = {"audio_features"}
|
||||
|
||||
_test_processing_correctness(
|
||||
model_id,
|
||||
@ -328,38 +332,26 @@ def test_processing_correctness_phi3v(
|
||||
)
|
||||
|
||||
|
||||
def _inputs_equal(
|
||||
def _assert_inputs_equal(
|
||||
a: MultiModalInputs,
|
||||
b: MultiModalInputs,
|
||||
ignore_mm_keys: Optional[list[str]] = None,
|
||||
*,
|
||||
ignore_mm_keys: Optional[set[str]] = None,
|
||||
msg: str = "",
|
||||
):
|
||||
return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys(
|
||||
b, ignore_mm_keys)
|
||||
if ignore_mm_keys is None:
|
||||
ignore_mm_keys = set()
|
||||
|
||||
if msg is None:
|
||||
assert "mm_kwargs" in a and "mm_kwargs" in b
|
||||
else:
|
||||
assert "mm_kwargs" in a and "mm_kwargs" in b, msg
|
||||
|
||||
def _drop_mm_kwargs_keys(
|
||||
result: MultiModalInputs,
|
||||
ignore_mm_keys: Optional[list[str]] = None,
|
||||
) -> MultiModalInputs:
|
||||
"""Drop specified keys from result['mm_kwargs'].
|
||||
for key in ignore_mm_keys:
|
||||
a["mm_kwargs"].pop(key, None)
|
||||
b["mm_kwargs"].pop(key, None)
|
||||
|
||||
This is mainly to avoid doing exact match of audio_features in ultravox.
|
||||
|
||||
Args:
|
||||
result: Result to drop keys from
|
||||
ignore_mm_keys: List of keys to ignore, e.g. ['audio_features']
|
||||
"""
|
||||
if not ignore_mm_keys:
|
||||
return result
|
||||
|
||||
if 'mm_kwargs' in result:
|
||||
result = copy.deepcopy(result)
|
||||
mm_kwargs = result['mm_kwargs']
|
||||
for key in ignore_mm_keys:
|
||||
mm_kwargs.pop(key, None)
|
||||
for items in mm_kwargs._items_by_modality.values():
|
||||
for item in items:
|
||||
for key in ignore_mm_keys:
|
||||
item.pop(key, None)
|
||||
|
||||
return result
|
||||
if msg is None:
|
||||
assert a == b
|
||||
else:
|
||||
assert a == b, msg
|
||||
|
@ -29,7 +29,7 @@ def test_processor_override(
|
||||
num_imgs: int,
|
||||
kwargs_on_init: bool,
|
||||
):
|
||||
"""Ensure input_processor_for_idefics3 handles num_crops properly."""
|
||||
"""Ensure Idefics3MultiModalProcessor handles num_crops properly."""
|
||||
# Same as the previous test - don't initialize mm_processor_kwargs
|
||||
# in this test and assume that the kwargs will be correctly expanded by
|
||||
# the partial when calling the custom input processor.
|
||||
|
@ -30,7 +30,7 @@ def test_processor_override(
|
||||
num_imgs: int,
|
||||
kwargs_on_init: bool,
|
||||
):
|
||||
"""Ensure input_processor_for_phi3v handles num_crops properly."""
|
||||
"""Ensure Phi3VMultiModalProcessor handles num_crops properly."""
|
||||
# Avoid initializing CUDA early
|
||||
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
|
||||
|
||||
|
@ -319,7 +319,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
}
|
||||
|
||||
_FALLBACK_MODEL = {
|
||||
"TransformersModel": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
|
||||
"TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
|
||||
}
|
||||
|
||||
_EXAMPLE_MODELS = {
|
||||
|
@ -3,8 +3,6 @@
|
||||
|
||||
Run `pytest tests/models/test_transformers.py`.
|
||||
"""
|
||||
from contextlib import nullcontext
|
||||
|
||||
import pytest
|
||||
|
||||
from ..conftest import HfRunner, VllmRunner
|
||||
@ -42,7 +40,6 @@ def check_implementation(
|
||||
"model,model_impl",
|
||||
[
|
||||
("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
|
||||
("openai-community/gpt2", "transformers"),
|
||||
("ArthurZ/Ilama-3.2-1B", "auto"), # CUSTOM CODE
|
||||
]) # trust_remote_code=True by default
|
||||
def test_models(
|
||||
@ -52,20 +49,11 @@ def test_models(
|
||||
model: str,
|
||||
model_impl: str,
|
||||
) -> None:
|
||||
|
||||
maybe_raises = nullcontext()
|
||||
if model == "openai-community/gpt2" and model_impl == "transformers":
|
||||
# Model is not backend compatible
|
||||
maybe_raises = pytest.raises(
|
||||
ValueError,
|
||||
match="The Transformers implementation.*not compatible with vLLM")
|
||||
|
||||
with maybe_raises:
|
||||
check_implementation(hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model,
|
||||
model_impl=model_impl)
|
||||
check_implementation(hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model,
|
||||
model_impl=model_impl)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
|
@ -23,8 +23,14 @@ MODELS = [
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model_id", MODELS)
|
||||
@pytest.mark.parametrize("force_marlin", [False, True])
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
|
||||
def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
|
||||
monkeypatch) -> None:
|
||||
use_rocm_aiter: bool, monkeypatch) -> None:
|
||||
|
||||
if use_rocm_aiter:
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
if force_marlin:
|
||||
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
|
||||
|
||||
@ -47,7 +53,13 @@ KV_CACHE_MODELS = [
|
||||
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
|
||||
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch):
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
|
||||
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
|
||||
use_rocm_aiter: bool, monkeypatch):
|
||||
if use_rocm_aiter:
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
# vllm_runner.apply_model() relies on V0 internals.
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
|
||||
@ -86,8 +98,13 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch):
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
|
||||
@pytest.mark.parametrize("force_marlin", [False, True])
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
|
||||
def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
|
||||
monkeypatch) -> None:
|
||||
use_rocm_aiter: bool, monkeypatch) -> None:
|
||||
if use_rocm_aiter:
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
# vllm_runner.apply_model() relies on V0 internals.
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
|
||||
|
@ -2,6 +2,8 @@
|
||||
# ruff: noqa
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import pickle
|
||||
import socket
|
||||
from collections.abc import AsyncIterator
|
||||
from unittest.mock import patch
|
||||
@ -14,7 +16,8 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
|
||||
PlaceholderModule, StoreBoolean, bind_kv_cache,
|
||||
deprecate_kwargs, get_open_port, memory_profiling,
|
||||
merge_async_iterators, supports_kw, swap_dict_values)
|
||||
merge_async_iterators, sha256, supports_kw,
|
||||
swap_dict_values)
|
||||
|
||||
from .utils import create_new_process_for_each_test, error_on_warning
|
||||
|
||||
@ -476,3 +479,21 @@ def test_swap_dict_values(obj, key1, key2):
|
||||
assert obj[key1] == original_obj[key2]
|
||||
else:
|
||||
assert key1 not in obj
|
||||
|
||||
@pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
|
||||
(None, bool, [1, 2, 3])])
|
||||
@pytest.mark.parametrize("output", [0, 1, 2])
|
||||
def test_sha256(input: tuple, output: int):
|
||||
hash = sha256(input)
|
||||
assert hash is not None
|
||||
assert isinstance(hash, int)
|
||||
assert hash != 0
|
||||
|
||||
bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), byteorder="big")
|
||||
|
||||
# hashing again, returns the same value
|
||||
assert hash == sha256(input)
|
||||
|
||||
# hashing different input, returns different value
|
||||
assert hash != sha256(input + (1, ))
|
||||
|
@ -5,12 +5,8 @@ import os
|
||||
import tempfile
|
||||
|
||||
import depyf
|
||||
import pytest
|
||||
|
||||
from vllm.config import CompilationLevel
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not working; needs investigation.")
|
||||
def test_tpu_compilation():
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
with depyf.prepare_debug(temp_dir):
|
||||
@ -22,27 +18,24 @@ def test_tpu_compilation():
|
||||
"The greatest glory in living lies not in never falling,",
|
||||
]
|
||||
answers = [
|
||||
" or, through inaction, allow a human being to come to harm.",
|
||||
" what is essential is invisible to the eye.",
|
||||
" but in rising every time we fall.",
|
||||
" or, through inaction",
|
||||
" what is essential ",
|
||||
" but in rising ",
|
||||
]
|
||||
N = 1
|
||||
|
||||
# Currently, top-p sampling is disabled. `top_p` should be 1.0.
|
||||
N = 1
|
||||
sampling_params = SamplingParams(temperature=0.7,
|
||||
top_p=1.0,
|
||||
n=N,
|
||||
max_tokens=16)
|
||||
|
||||
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
|
||||
# In real workloads, `enforace_eager` should be `False`.
|
||||
|
||||
# disable custom dispatcher, let Dynamo takes over
|
||||
# all the control
|
||||
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
|
||||
max_model_len=512,
|
||||
max_num_seqs=64,
|
||||
enforce_eager=True,
|
||||
compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})
|
||||
max_num_batched_tokens=256,
|
||||
max_model_len=256,
|
||||
max_num_seqs=32,
|
||||
enforce_eager=False)
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
for output, answer in zip(outputs, answers):
|
||||
prompt = output.prompt
|
||||
@ -56,16 +49,11 @@ def test_tpu_compilation():
|
||||
for i, compiled_code in enumerate(compiled_codes):
|
||||
print("{} file: {}".format(i + 1, compiled_code))
|
||||
|
||||
# We should only trigger Dynamo compilation 4 times:
|
||||
# 1. forward pass (symbolic)
|
||||
# 2. compute_logits (symbolic)
|
||||
# 3. forward pass (shape 16)
|
||||
# 4. forward pass (shape 32)
|
||||
# and later calls should not trigger Dynamo compilation again.
|
||||
# NOTE: It might still trigger XLA compilation.
|
||||
|
||||
# We should only trigger Dynamo compilation 2 times:
|
||||
# 1. Forward pass without kv_caches
|
||||
# 2. Forward pass with kv_caches
|
||||
# Check we have 4 compiled codes
|
||||
assert len(compiled_codes) == 4
|
||||
assert len(compiled_codes) == 2
|
||||
|
||||
kv_cache_prefix = "kv_cache"
|
||||
attn_prefix = "ragged_paged_attention"
|
||||
@ -77,24 +65,13 @@ def test_tpu_compilation():
|
||||
for i, compiled_fn in enumerate(compiled_fns):
|
||||
print("{} file: {}".format(i + 1, compiled_fn))
|
||||
|
||||
# The first compilation is symbolic, so it should not have any kv_caches
|
||||
# The first compilation should not have any kv_caches
|
||||
with open(compiled_fns[0]) as f:
|
||||
content = f.read()
|
||||
assert kv_cache_prefix not in content
|
||||
|
||||
# The second compilation is symbolic, so it should not have any kv_caches
|
||||
# The second compilation should have kv_caches and the
|
||||
# ragged_paged_attention
|
||||
with open(compiled_fns[1]) as f:
|
||||
content = f.read()
|
||||
assert kv_cache_prefix not in content
|
||||
|
||||
# The third compilation is shape 16, so it should have kv_caches and the
|
||||
# ragged_paged_attention
|
||||
with open(compiled_fns[2]) as f:
|
||||
content = f.read()
|
||||
assert (kv_cache_prefix in content and attn_prefix in content)
|
||||
|
||||
# The forth compilation is shape 32, so it should have kv_caches and the
|
||||
# ragged_paged_attention
|
||||
with open(compiled_fns[3]) as f:
|
||||
content = f.read()
|
||||
assert (kv_cache_prefix in content and attn_prefix in content)
|
||||
|
@ -5,8 +5,12 @@ import torch
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalKwargs
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
|
||||
KVCacheBlock, PrefixCachingMetrics,
|
||||
from vllm.utils import sha256
|
||||
# disable yapf here as it formats differently than isort such that both fail
|
||||
# yapf: disable
|
||||
from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
|
||||
FreeKVCacheBlockQueue, KVCacheBlock,
|
||||
PrefixCachingMetrics,
|
||||
generate_block_hash_extra_keys,
|
||||
hash_block_tokens,
|
||||
hash_request_tokens,
|
||||
@ -16,6 +20,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def make_request(request_id,
|
||||
prompt_token_ids,
|
||||
@ -40,6 +46,12 @@ def make_request(request_id,
|
||||
)
|
||||
|
||||
|
||||
def test_none_hash():
|
||||
assert NONE_HASH is not None
|
||||
assert isinstance(NONE_HASH, int)
|
||||
assert NONE_HASH != 0
|
||||
|
||||
|
||||
def test_kv_cache_block():
|
||||
# Test KVCacheBlock initialization
|
||||
block = KVCacheBlock(block_id=0)
|
||||
@ -190,21 +202,23 @@ def test_generate_block_hash_extra_keys_no_mm_inputs():
|
||||
assert next_mm_idx == 0
|
||||
|
||||
|
||||
def test_hash_block_tokens():
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, hash])
|
||||
def test_hash_block_tokens(hash_fn):
|
||||
parent_block_hash = 123
|
||||
curr_block_token_ids = (1, 2, 3)
|
||||
extra_keys = ("key1", "key2")
|
||||
|
||||
block_hash = hash_block_tokens(parent_block_hash, curr_block_token_ids,
|
||||
extra_keys)
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
curr_block_token_ids, extra_keys)
|
||||
assert isinstance(block_hash, BlockHashType)
|
||||
assert block_hash.hash_value == hash(
|
||||
assert block_hash.hash_value == hash_fn(
|
||||
(parent_block_hash, curr_block_token_ids, extra_keys))
|
||||
assert block_hash.token_ids == curr_block_token_ids
|
||||
assert block_hash.extra_keys == extra_keys
|
||||
|
||||
|
||||
def test_hash_request_tokens():
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, hash])
|
||||
def test_hash_request_tokens(hash_fn):
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
@ -219,7 +233,7 @@ def test_hash_request_tokens():
|
||||
)
|
||||
|
||||
block_size = 3
|
||||
block_hashes = hash_request_tokens(block_size, request)
|
||||
block_hashes = hash_request_tokens(hash_fn, block_size, request)
|
||||
|
||||
assert len(block_hashes) == 2
|
||||
assert isinstance(block_hashes[0], BlockHashType)
|
||||
@ -234,7 +248,8 @@ def test_hash_request_tokens():
|
||||
assert block_hashes[1].extra_keys == ("hash2", )
|
||||
|
||||
|
||||
def test_hash_tokens_different_mm_input():
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, hash])
|
||||
def test_hash_tokens_different_mm_input(hash_fn):
|
||||
request1 = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
@ -260,13 +275,14 @@ def test_hash_tokens_different_mm_input():
|
||||
mm_hashes=["hash3", "hash2"],
|
||||
)
|
||||
block_size = 3
|
||||
block_hashes1 = hash_request_tokens(block_size, request1)
|
||||
block_hashes2 = hash_request_tokens(block_size, request2)
|
||||
block_hashes1 = hash_request_tokens(hash_fn, block_size, request1)
|
||||
block_hashes2 = hash_request_tokens(hash_fn, block_size, request2)
|
||||
assert block_hashes1[0] != block_hashes2[0]
|
||||
assert block_hashes1[1] != block_hashes2[1]
|
||||
|
||||
|
||||
def test_hash_request_tokens_no_mm_inputs():
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, hash])
|
||||
def test_hash_request_tokens_no_mm_inputs(hash_fn):
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
@ -275,7 +291,7 @@ def test_hash_request_tokens_no_mm_inputs():
|
||||
)
|
||||
|
||||
block_size = 3
|
||||
block_hashes = hash_request_tokens(block_size, request)
|
||||
block_hashes = hash_request_tokens(hash_fn, block_size, request)
|
||||
|
||||
assert len(block_hashes) == 2
|
||||
assert block_hashes[0].token_ids == (0, 1, 2)
|
||||
|
@ -7,7 +7,7 @@ import pytest
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import cdiv
|
||||
from vllm.utils import cdiv, sha256
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
|
||||
@ -39,16 +39,21 @@ def make_request(request_id,
|
||||
)
|
||||
|
||||
|
||||
def test_prefill():
|
||||
@pytest.mark.parametrize("hash_algo", ["sha256", "hash"])
|
||||
def test_prefill(hash_algo):
|
||||
manager = KVCacheManager(
|
||||
block_size=16,
|
||||
num_gpu_blocks=10,
|
||||
max_model_len=8192,
|
||||
sliding_window=None,
|
||||
enable_caching=True,
|
||||
caching_hash_algo=hash_algo,
|
||||
num_preallocate_tokens=16,
|
||||
)
|
||||
|
||||
# choose the hash function according to the parameter
|
||||
hash_fn = sha256 if hash_algo == "sha256" else hash
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
|
||||
@ -68,7 +73,8 @@ def test_prefill():
|
||||
parent_block_hash = None
|
||||
for block_id in (0, 1, 2):
|
||||
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
|
||||
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
block_tokens)
|
||||
assert manager.block_pool.blocks[block_id].block_hash == block_hash
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash.hash_value
|
||||
@ -163,6 +169,8 @@ def test_prefill_plp():
|
||||
enable_caching=True,
|
||||
num_preallocate_tokens=16,
|
||||
)
|
||||
# the default hash function is hash
|
||||
hash_fn = hash
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
@ -185,7 +193,8 @@ def test_prefill_plp():
|
||||
parent_block_hash = None
|
||||
for block_id in (0, 1, 2):
|
||||
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
|
||||
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
block_tokens)
|
||||
assert manager.block_pool.blocks[block_id].block_hash == block_hash
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash.hash_value
|
||||
@ -522,7 +531,8 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
|
||||
assert len(blocks) == 1 + num_preallocated_blocks
|
||||
|
||||
|
||||
def test_cache_blocks():
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, hash])
|
||||
def test_cache_blocks(hash_fn):
|
||||
"""
|
||||
This is a unit test that tests the correctness of the _cache_full_blocks
|
||||
function of KVCacheManager.
|
||||
@ -550,6 +560,7 @@ def test_cache_blocks():
|
||||
num_cached_blocks=0,
|
||||
num_full_blocks=2,
|
||||
block_size=block_size,
|
||||
hash_fn=hash_fn,
|
||||
)
|
||||
|
||||
assert len(block_pool.cached_block_hash_to_block) == 2
|
||||
@ -564,6 +575,7 @@ def test_cache_blocks():
|
||||
num_cached_blocks=2,
|
||||
num_full_blocks=3,
|
||||
block_size=block_size,
|
||||
hash_fn=hash_fn,
|
||||
)
|
||||
assert len(block_pool.cached_block_hash_to_block) == 3
|
||||
assert blocks[0].block_hash is not None
|
||||
|
@ -20,9 +20,10 @@ def create_scheduler(
|
||||
max_num_seqs: int = 16,
|
||||
max_num_batched_tokens: int = 8192,
|
||||
enable_prefix_caching: Optional[bool] = None,
|
||||
long_prefill_token_threshold: int = 0,
|
||||
) -> Scheduler:
|
||||
'''Create scheduler under test.
|
||||
|
||||
|
||||
Args:
|
||||
model: model under test
|
||||
max_num_seqs: max sequences to schedule
|
||||
@ -38,6 +39,7 @@ def create_scheduler(
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
max_model_len=max_num_batched_tokens,
|
||||
long_prefill_token_threshold=long_prefill_token_threshold,
|
||||
)
|
||||
model_config = ModelConfig(
|
||||
model=model,
|
||||
@ -242,7 +244,9 @@ def test_schedule_partial_requests():
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[request.request_id for request in requests],
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[[0] for _ in range(len(requests))],
|
||||
# Only the first request has a sampled token id because
|
||||
# the rest requests are still being prefilled.
|
||||
sampled_token_ids=[[0], [], []],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
@ -263,6 +267,86 @@ def test_schedule_partial_requests():
|
||||
assert requests[2].request_id not in output.num_scheduled_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_prefix_caching", [True, False])
|
||||
def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
|
||||
"""Test scheduling behavior with concurrent partial requests.
|
||||
|
||||
This test verifies that: there are multiple long prefill requests in the
|
||||
RUNNING state, and we can schedule them together.
|
||||
|
||||
"""
|
||||
scheduler = create_scheduler(
|
||||
model="facebook/opt-125m",
|
||||
max_num_batched_tokens=1024,
|
||||
long_prefill_token_threshold=400,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
)
|
||||
requests = create_requests(
|
||||
num_requests=3,
|
||||
num_tokens=800,
|
||||
)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
output = scheduler.schedule()
|
||||
assert len(output.scheduled_new_reqs) == 3
|
||||
assert len(output.scheduled_cached_reqs) == 0
|
||||
assert len(output.finished_req_ids) == 0
|
||||
|
||||
# The first request is scheduled partially - 400.
|
||||
assert output.num_scheduled_tokens[requests[0].request_id] == 400
|
||||
# The second request is scheduled partially - 400.
|
||||
assert output.num_scheduled_tokens[requests[1].request_id] == 400
|
||||
# The third request is also scheduled partially - 1024 - 400 - 400 = 224.
|
||||
assert output.num_scheduled_tokens[requests[2].request_id] == 224
|
||||
req_to_index = {
|
||||
request.request_id: i
|
||||
for i, request in enumerate(requests)
|
||||
}
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[request.request_id for request in requests],
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[[] for _ in range(len(requests))],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
)
|
||||
scheduler.update_from_output(output, model_runner_output)
|
||||
|
||||
# Schedule the next step. All three requests are running.
|
||||
# Processed the remaining prefills of the first and second requests.
|
||||
output1 = scheduler.schedule()
|
||||
assert len(scheduler.running) == 3
|
||||
assert len(output1.scheduled_new_reqs) == 0
|
||||
assert len(output1.scheduled_cached_reqs) == 3
|
||||
assert len(output1.finished_req_ids) == 0
|
||||
assert output1.num_scheduled_tokens[requests[0].request_id] == 400
|
||||
assert output1.num_scheduled_tokens[requests[1].request_id] == 400
|
||||
assert output1.num_scheduled_tokens[requests[2].request_id] == 224
|
||||
|
||||
# Schedule the third step. All three requests are running.
|
||||
# First and second requests are in the decode stage.
|
||||
# All the remaining tokens in the third request are processed.
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[request.request_id for request in requests],
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
)
|
||||
scheduler.update_from_output(output1, model_runner_output)
|
||||
output2 = scheduler.schedule()
|
||||
assert len(scheduler.running) == 3
|
||||
assert len(output2.scheduled_new_reqs) == 0
|
||||
assert len(output2.scheduled_cached_reqs) == 3
|
||||
assert len(output2.finished_req_ids) == 0
|
||||
assert output2.num_scheduled_tokens[requests[0].request_id] == 1
|
||||
assert output2.num_scheduled_tokens[requests[1].request_id] == 1
|
||||
assert output2.num_scheduled_tokens[
|
||||
requests[2].request_id] == 800 - 224 - 224
|
||||
|
||||
|
||||
def test_stop_via_update_from_output():
|
||||
"""Test stopping behavior through update_from_output"""
|
||||
scheduler = create_scheduler()
|
||||
|
29
tests/v1/core/test_scheduler_e2e.py
Normal file
29
tests/v1/core/test_scheduler_e2e.py
Normal file
@ -0,0 +1,29 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
if os.getenv("VLLM_USE_V1", "0") != "1":
|
||||
pytest.skip("Test package requires V1", allow_module_level=True)
|
||||
|
||||
MODEL = "meta-llama/Llama-3.2-1B"
|
||||
PROMPT = "Hello my name is Robert and I"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def model() -> LLM:
|
||||
return LLM(MODEL,
|
||||
enforce_eager=True,
|
||||
enable_prefix_caching=True,
|
||||
long_prefill_token_threshold=2,
|
||||
max_num_batched_tokens=6,
|
||||
max_num_seqs=3)
|
||||
|
||||
|
||||
def test_concurrent_partial_prefill(model):
|
||||
outputs = model.generate([PROMPT] * 3)
|
||||
assert len(outputs) == 3
|
||||
for output in outputs:
|
||||
assert len(output.outputs) == 1
|
@ -1,5 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from argparse import ArgumentError
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import envs
|
||||
@ -32,6 +34,24 @@ def test_prefix_caching_from_cli():
|
||||
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
||||
assert vllm_config.cache_config.enable_prefix_caching
|
||||
|
||||
# default hash algorithm is "builtin"
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
|
||||
|
||||
# set hash algorithm to sha256
|
||||
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"])
|
||||
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
|
||||
|
||||
# set hash algorithm to builtin
|
||||
args = parser.parse_args(["--prefix-caching-hash-algo", "builtin"])
|
||||
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
|
||||
|
||||
# an invalid hash algorithm raises an error
|
||||
parser.exit_on_error = False
|
||||
with pytest.raises(ArgumentError):
|
||||
args = parser.parse_args(["--prefix-caching-hash-algo", "invalid"])
|
||||
|
||||
|
||||
def test_defaults_with_usage_context():
|
||||
engine_args = EngineArgs(model="facebook/opt-125m")
|
||||
|
@ -231,8 +231,10 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
||||
Test that the engine can handle multiple concurrent batches.
|
||||
"""
|
||||
|
||||
def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest:
|
||||
def make_request_with_max_tokens(req_id: int,
|
||||
max_tokens: int) -> EngineCoreRequest:
|
||||
request = make_request()
|
||||
request.request_id = req_id
|
||||
request.sampling_params.max_tokens = max_tokens
|
||||
return request
|
||||
|
||||
@ -279,6 +281,8 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
||||
# Avoid all requests being scheduled once.
|
||||
enable_prefix_caching=False,
|
||||
max_num_batched_tokens=10,
|
||||
# Reduce startup time.
|
||||
enforce_eager=True,
|
||||
)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
@ -286,13 +290,13 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
||||
executor_class=DummyExecutor)
|
||||
assert engine_core.batch_queue is not None
|
||||
|
||||
# Add two requests in a row.
|
||||
req = make_request_with_max_tokens(5)
|
||||
engine_core.add_request(req)
|
||||
req = make_request_with_max_tokens(5)
|
||||
engine_core.add_request(req)
|
||||
# Add two requests in a row. Each request have 12 prompt tokens.
|
||||
req0 = make_request_with_max_tokens(0, 5)
|
||||
engine_core.add_request(req0)
|
||||
req1 = make_request_with_max_tokens(1, 5)
|
||||
engine_core.add_request(req1)
|
||||
|
||||
# First saturate the batch queue.
|
||||
# Schedule Batch 1: (10, req0)
|
||||
assert engine_core.step_with_batch_queue() is None
|
||||
assert engine_core.batch_queue.qsize() == 1
|
||||
assert engine_core.step_with_batch_queue() is None
|
||||
|
@ -167,11 +167,11 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
core_client: SyncMPClient = client
|
||||
|
||||
result = core_client._call_utility("echo", "testarg")
|
||||
result = core_client.call_utility("echo", "testarg")
|
||||
assert result == "testarg"
|
||||
|
||||
with pytest.raises(Exception) as e_info:
|
||||
core_client._call_utility("echo", None, "help!")
|
||||
core_client.call_utility("echo", None, "help!")
|
||||
|
||||
assert str(e_info.value) == "Call to echo method failed: help!"
|
||||
|
||||
@ -238,10 +238,10 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
core_client: AsyncMPClient = client
|
||||
|
||||
result = await core_client._call_utility_async("echo", "testarg")
|
||||
result = await core_client.call_utility_async("echo", "testarg")
|
||||
assert result == "testarg"
|
||||
|
||||
with pytest.raises(Exception) as e_info:
|
||||
await core_client._call_utility_async("echo", None, "help!")
|
||||
await core_client.call_utility_async("echo", None, "help!")
|
||||
|
||||
assert str(e_info.value) == "Call to echo method failed: help!"
|
||||
|
37
tests/v1/sample/test_topk_topp_sampler.py
Normal file
37
tests/v1/sample/test_topk_topp_sampler.py
Normal file
@ -0,0 +1,37 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import torch
|
||||
from torch import Generator
|
||||
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
|
||||
DEVICE = "cuda"
|
||||
|
||||
BATCH_SIZE = 1024
|
||||
VOCAB_SIZE = 128 * 1024
|
||||
|
||||
|
||||
def test_topk_impl_equivalance():
|
||||
|
||||
with torch.device(DEVICE):
|
||||
generator = Generator(device=DEVICE).manual_seed(33)
|
||||
|
||||
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
|
||||
|
||||
# Random top-k values between 1 and 9.
|
||||
k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
|
||||
|
||||
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
|
||||
k.masked_fill_(
|
||||
torch.randint(0,
|
||||
2, (BATCH_SIZE, ),
|
||||
generator=generator,
|
||||
dtype=bool), VOCAB_SIZE)
|
||||
|
||||
# Top-k only implementation
|
||||
result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
|
||||
|
||||
# Top-p + top-k
|
||||
no_op_top_p = torch.tensor([1.0])
|
||||
result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
|
||||
|
||||
assert torch.allclose(result1, result2)
|
109
tests/v1/test_async_llm_dp.py
Normal file
109
tests/v1/test_async_llm_dp.py
Normal file
@ -0,0 +1,109 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from contextlib import ExitStack
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.v1.engine.core_client import DPAsyncMPClient
|
||||
|
||||
engine_args = AsyncEngineArgs(
|
||||
model="ibm-research/PowerMoE-3b",
|
||||
enforce_eager=True,
|
||||
disable_log_requests=True,
|
||||
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
|
||||
data_parallel_size=int(os.getenv("DP_SIZE", 2)),
|
||||
)
|
||||
|
||||
if not current_platform.supports_v1(engine_args.create_model_config()):
|
||||
pytest.skip(reason="Requires V1-supporting platform.",
|
||||
allow_module_level=True)
|
||||
|
||||
|
||||
async def generate(engine: AsyncLLM,
|
||||
request_id: str,
|
||||
prompt: PromptType,
|
||||
output_kind: RequestOutputKind,
|
||||
max_tokens: int,
|
||||
prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
|
||||
# Ensure generate doesn't complete too fast for cancellation test.
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
count = 0
|
||||
sampling_params = SamplingParams(max_tokens=max_tokens,
|
||||
ignore_eos=True,
|
||||
output_kind=output_kind,
|
||||
temperature=0,
|
||||
prompt_logprobs=prompt_logprobs)
|
||||
async for out in engine.generate(request_id=request_id,
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params):
|
||||
|
||||
num_tokens = len(out.outputs[0].token_ids)
|
||||
if output_kind == RequestOutputKind.DELTA:
|
||||
count += num_tokens
|
||||
else:
|
||||
count = num_tokens
|
||||
|
||||
await asyncio.sleep(0.)
|
||||
|
||||
return count, request_id
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
|
||||
@pytest.mark.asyncio
|
||||
async def test_load(output_kind: RequestOutputKind):
|
||||
|
||||
with ExitStack() as after:
|
||||
|
||||
prompt = "This is a test of data parallel"
|
||||
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
NUM_REQUESTS = 100
|
||||
NUM_EXPECTED_TOKENS = 10
|
||||
|
||||
request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
|
||||
|
||||
# Create concurrent requests.
|
||||
tasks = []
|
||||
for request_id in request_ids:
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
generate(engine, request_id, prompt, output_kind,
|
||||
NUM_EXPECTED_TOKENS)))
|
||||
|
||||
# Confirm that we got all the EXPECTED tokens from the requests.
|
||||
done, pending = await asyncio.wait(tasks,
|
||||
return_when=asyncio.FIRST_EXCEPTION)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
for task in done:
|
||||
num_generated_tokens, request_id = await task
|
||||
assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
|
||||
f"{request_id} generated {num_generated_tokens} but "
|
||||
f"expected {NUM_EXPECTED_TOKENS}")
|
||||
|
||||
assert not engine.output_processor.has_unfinished_requests()
|
||||
|
||||
# testing internals here which may break
|
||||
core_client: DPAsyncMPClient = engine.engine_core
|
||||
# the engines only synchronize stopping every N steps so
|
||||
# allow a small amount of time here.
|
||||
for _ in range(10):
|
||||
if core_client.num_engines_running == 0:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
assert core_client.num_engines_running == 0
|
||||
assert not core_client.reqs_in_flight
|
@ -1,7 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import tempfile
|
||||
from time import time
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, envs
|
||||
@ -15,60 +12,6 @@ if not envs.VLLM_USE_V1:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
|
||||
@pytest.mark.skipif(not current_platform.is_tpu(),
|
||||
reason="This test needs a TPU")
|
||||
def test_sampler_compilation(model_name: str, monkeypatch):
|
||||
"""
|
||||
Check that no recompilation happens despite changing sampling parameters.
|
||||
We can't read XLA metrics from the engine process, hence we measure time.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
monkeypatch.setenv("VLLM_XLA_CACHE_PATH", temp_dir)
|
||||
# Compiling model init may still take some time, enforce_eager to skip.
|
||||
llm = LLM(model_name,
|
||||
enforce_eager=True,
|
||||
max_num_seqs=16,
|
||||
max_model_len=1024,
|
||||
gpu_memory_utilization=0.5)
|
||||
prompts = [
|
||||
"A robot may not injure a human being",
|
||||
"It is only with the heart that one can see rightly;",
|
||||
]
|
||||
# First inference should be slow
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.7,
|
||||
# top_p=0.6, # TODO too slow!
|
||||
top_k=10,
|
||||
min_p=0.2,
|
||||
max_tokens=16)
|
||||
s = time()
|
||||
_ = llm.generate(prompts, sampling_params)
|
||||
run1 = time() - s
|
||||
|
||||
# Second request with different params, but for which we
|
||||
# compiled for in previous eager iteration.
|
||||
sampling_params = SamplingParams(temperature=0.1,
|
||||
top_k=12,
|
||||
min_p=0.8,
|
||||
max_tokens=24)
|
||||
s = time()
|
||||
_ = llm.generate(prompts, sampling_params)
|
||||
run2 = time() - s
|
||||
# Much faster after compiling
|
||||
assert run1 * 0.1 > run2
|
||||
print("TIMES", run1, run2)
|
||||
|
||||
# Third request with min_p set to "None". It will not trigger
|
||||
# recompilation as a default 0 value will be used.
|
||||
sampling_params = SamplingParams(max_tokens=24, temperature=0.0)
|
||||
s = time()
|
||||
_ = llm.generate(prompts, sampling_params)
|
||||
run3 = time() - s
|
||||
assert run1 * 0.1 > run3
|
||||
print("TIMES", run1, run3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
|
||||
@pytest.mark.skipif(not current_platform.is_tpu(),
|
||||
reason="This test needs a TPU")
|
||||
@ -77,13 +20,11 @@ def test_sampler_different(model_name: str):
|
||||
Test significantly different sampling params to assert the model produces
|
||||
different results.
|
||||
"""
|
||||
llm = LLM(
|
||||
model_name,
|
||||
enforce_eager=True,
|
||||
max_num_seqs=1,
|
||||
max_model_len=64,
|
||||
# TODO: setting to 0.5 or it will go OOM
|
||||
gpu_memory_utilization=0.5)
|
||||
llm = LLM(model_name,
|
||||
enforce_eager=False,
|
||||
max_num_seqs=1,
|
||||
max_model_len=512,
|
||||
max_num_batched_tokens=512)
|
||||
prompts = [
|
||||
"Write a short story about a robot that dreams for the first time."
|
||||
]
|
||||
|
327
tests/v1/tpu/worker/test_tpu_model_runner.py
Normal file
327
tests/v1/tpu/worker/test_tpu_model_runner.py
Normal file
@ -0,0 +1,327 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import unittest.mock as mock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
|
||||
SchedulerOutput)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.tpu_model_runner import (TPUModelRunner,
|
||||
_get_padded_token_len,
|
||||
_get_paddings)
|
||||
|
||||
# Mock torch_xla module since it may not be available in the test environments
|
||||
torch_xla_patcher = mock.patch.dict(
|
||||
"sys.modules", {
|
||||
"torch_xla": mock.MagicMock(),
|
||||
"torch_xla.core.xla_model": mock.MagicMock(),
|
||||
"torch_xla.runtime": mock.MagicMock(),
|
||||
})
|
||||
torch_xla_patcher.start()
|
||||
|
||||
# Mock the PallasAttentionBackend
|
||||
pallas_attention_backend_patcher = mock.patch(
|
||||
"vllm.v1.worker.tpu_model_runner.PallasAttentionBackend", )
|
||||
pallas_attention_backend_patcher.start()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_runner():
|
||||
# Patchers have already been started at module level.
|
||||
scheduler_config = SchedulerConfig(
|
||||
max_num_seqs=10,
|
||||
max_num_batched_tokens=512,
|
||||
max_model_len=512,
|
||||
)
|
||||
model_config = ModelConfig(
|
||||
model="facebook/opt-125m",
|
||||
task="generate",
|
||||
tokenizer="facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=True,
|
||||
dtype="bfloat16", # TPUs typically use bfloat16
|
||||
seed=42,
|
||||
)
|
||||
cache_config = CacheConfig(
|
||||
block_size=16,
|
||||
gpu_memory_utilization=0.9,
|
||||
swap_space=0,
|
||||
cache_dtype="auto",
|
||||
)
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
scheduler_config=scheduler_config,
|
||||
)
|
||||
device = "xla:0" # Mocking TPU device
|
||||
with mock.patch("vllm.v1.worker.tpu_model_runner.torch"), \
|
||||
mock.patch("vllm.v1.worker.tpu_model_runner.xm"), \
|
||||
mock.patch("vllm.v1.worker.tpu_model_runner.xr"):
|
||||
return TPUModelRunner(vllm_config, device)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="session")
|
||||
def cleanup_patches():
|
||||
yield
|
||||
torch_xla_patcher.stop()
|
||||
pallas_attention_backend_patcher.stop()
|
||||
|
||||
|
||||
def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
||||
new_reqs = []
|
||||
num_scheduled_tokens = {}
|
||||
total_num_scheduled_tokens = 0
|
||||
for req_id in req_ids:
|
||||
new_reqs.append(
|
||||
NewRequestData(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=[1, 2, 3],
|
||||
prompt="test",
|
||||
mm_inputs=[],
|
||||
mm_hashes=[],
|
||||
mm_positions=[],
|
||||
sampling_params=SamplingParams(),
|
||||
block_ids=[0],
|
||||
num_computed_tokens=0,
|
||||
lora_request=None,
|
||||
))
|
||||
num_scheduled_tokens[req_id] = 3
|
||||
total_num_scheduled_tokens += num_scheduled_tokens[req_id]
|
||||
|
||||
return SchedulerOutput(
|
||||
scheduled_new_reqs=new_reqs,
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
|
||||
def _is_req_scheduled(model_runner, req_id: str) -> bool:
|
||||
return req_id in model_runner.input_batch.req_id_to_index
|
||||
|
||||
|
||||
def _is_req_added(model_runner, req_id: str) -> bool:
|
||||
return req_id in model_runner.requests
|
||||
|
||||
|
||||
def _is_sampling_metadata_changed(model_runner,
|
||||
sampling_metadata_before: SamplingMetadata):
|
||||
return model_runner.input_batch.sampling_metadata is not (
|
||||
sampling_metadata_before)
|
||||
|
||||
|
||||
def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
|
||||
req_index = model_runner.input_batch.req_id_to_index[req_id]
|
||||
block_table = model_runner.input_batch.block_table
|
||||
req_state = model_runner.requests[req_id]
|
||||
if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
|
||||
return False
|
||||
num_blocks = block_table.num_blocks_per_row[req_index]
|
||||
return (block_table.block_table_np[req_index, :num_blocks] ==
|
||||
req_state.block_ids).all()
|
||||
|
||||
|
||||
def test_update_states_new_request(model_runner):
|
||||
req_id = "req_0"
|
||||
|
||||
# new req
|
||||
scheduler_output = _schedule_new_request(req_id)
|
||||
|
||||
metadata_before = model_runner.input_batch.sampling_metadata
|
||||
model_runner._update_states(scheduler_output)
|
||||
|
||||
assert _is_sampling_metadata_changed(model_runner, metadata_before)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert _is_req_scheduled(model_runner, req_id)
|
||||
assert _is_req_state_block_table_match(model_runner, req_id)
|
||||
|
||||
|
||||
def test_update_states_request_finished(model_runner):
|
||||
req_id = "req_0"
|
||||
|
||||
# new req
|
||||
scheduler_output = _schedule_new_request(req_id)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
# finish req
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={},
|
||||
total_num_scheduled_tokens=0,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids={req_id},
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
metadata_before = model_runner.input_batch.sampling_metadata
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert _is_sampling_metadata_changed(model_runner, metadata_before)
|
||||
assert not _is_req_added(model_runner, req_id)
|
||||
assert not _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
|
||||
def test_update_states_request_resumed(model_runner):
|
||||
req_id = "req_0"
|
||||
|
||||
# new req
|
||||
scheduler_output = _schedule_new_request(req_id)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
# unschedule req
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={},
|
||||
total_num_scheduled_tokens=0,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert not _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
# resume req
|
||||
cached_req_data = CachedRequestData(
|
||||
req_id=req_id,
|
||||
resumed_from_preemption=False,
|
||||
new_token_ids=[],
|
||||
new_block_ids=[],
|
||||
num_computed_tokens=0,
|
||||
)
|
||||
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[cached_req_data],
|
||||
num_scheduled_tokens={req_id: 1},
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
metadata_before = model_runner.input_batch.sampling_metadata
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert _is_sampling_metadata_changed(model_runner, metadata_before)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert _is_req_scheduled(model_runner, req_id)
|
||||
assert _is_req_state_block_table_match(model_runner, req_id)
|
||||
|
||||
|
||||
def test_update_states_no_changes(model_runner):
|
||||
req_id = "req_0"
|
||||
|
||||
# new req
|
||||
scheduler_output = _schedule_new_request(req_id)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert _is_req_scheduled(model_runner, req_id)
|
||||
|
||||
# schedule req
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={req_id: 1},
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
metadata_before = model_runner.input_batch.sampling_metadata
|
||||
model_runner._update_states(scheduler_output)
|
||||
assert not _is_sampling_metadata_changed(model_runner, metadata_before)
|
||||
assert _is_req_added(model_runner, req_id)
|
||||
assert _is_req_scheduled(model_runner, req_id)
|
||||
assert _is_req_state_block_table_match(model_runner, req_id)
|
||||
|
||||
|
||||
def test_update_states_request_unscheduled(model_runner):
|
||||
req_ids = ("req_0", "req_1")
|
||||
|
||||
# new reqs
|
||||
scheduler_output = _schedule_new_request(*req_ids)
|
||||
|
||||
model_runner._update_states(scheduler_output)
|
||||
|
||||
assert _is_req_added(model_runner, req_ids[0])
|
||||
assert _is_req_scheduled(model_runner, req_ids[0])
|
||||
|
||||
assert _is_req_added(model_runner, req_ids[1])
|
||||
assert _is_req_scheduled(model_runner, req_ids[1])
|
||||
|
||||
# unschedule req_1
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={req_ids[0]: 1},
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
metadata_before = model_runner._update_states(scheduler_output)
|
||||
assert _is_sampling_metadata_changed(model_runner, metadata_before)
|
||||
|
||||
assert _is_req_added(model_runner, req_ids[0])
|
||||
assert _is_req_scheduled(model_runner, req_ids[0])
|
||||
|
||||
assert _is_req_added(model_runner, req_ids[1])
|
||||
assert not _is_req_scheduled(model_runner, req_ids[1])
|
||||
|
||||
|
||||
def test_get_paddings():
|
||||
min_token_size, max_token_size, padding_gap = 16, 512, 64
|
||||
expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
|
||||
actual_paddings = _get_paddings(min_token_size, max_token_size,
|
||||
padding_gap)
|
||||
assert actual_paddings == expected_paddings
|
||||
|
||||
|
||||
def test_get_padded_token_len():
|
||||
min_token_size, max_token_size, padding_gap = 16, 512, 64
|
||||
paddings = _get_paddings(min_token_size, max_token_size, padding_gap)
|
||||
assert _get_padded_token_len(paddings, 1) == 16
|
||||
assert _get_padded_token_len(paddings, 16) == 16
|
||||
assert _get_padded_token_len(paddings, 20) == 32
|
||||
assert _get_padded_token_len(paddings, 300) == 320
|
||||
assert _get_padded_token_len(paddings, 512) == 512
|
@ -124,6 +124,18 @@ def paged_attention_rocm(
|
||||
kv_cache_dtype, k_scale, v_scale)
|
||||
|
||||
|
||||
def mla_decode_kvcache_cpu(
|
||||
out: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops._C_cpu.mla_decode_kvcache(out, query, kv_cache, scale,
|
||||
block_tables, seq_lens)
|
||||
|
||||
|
||||
# pos encoding ops
|
||||
def rotary_embedding(
|
||||
positions: torch.Tensor,
|
||||
@ -575,6 +587,9 @@ def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
|
||||
cuda_device_capability)
|
||||
|
||||
|
||||
def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
|
||||
return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability)
|
||||
|
||||
def cutlass_sparse_compress(a: torch.Tensor) \
|
||||
-> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@ -665,6 +680,56 @@ def cutlass_scaled_sparse_mm(
|
||||
return out
|
||||
|
||||
|
||||
def get_cutlass_moe_mm_data(
|
||||
topk_ids: torch.Tensor, expert_offsets: torch.Tensor,
|
||||
problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor,
|
||||
input_permutation: torch.Tensor, output_permutation: torch.Tensor,
|
||||
num_experts: int, n: int, k: int):
|
||||
"""
|
||||
Prepare data necessary to perform CUTLASS grouped matrix multiplications
|
||||
used in CUTLASS-based fused MoE.
|
||||
|
||||
The function takes in topk_ids (token-expert mapping) and uses it to
|
||||
compute:
|
||||
- expert_offsets: Indices that mark at which token index each expert begins
|
||||
its computation after the input is sorted with
|
||||
input_permutation. The number of tokens computed with
|
||||
expert E is expert_offsets[E + 1] - expert_offsets[E]
|
||||
- problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
|
||||
multiplication in two grouped MMs used in
|
||||
the fused MoE operation.
|
||||
- input_permutation: Permutation that must be used to shuffle the input
|
||||
before executing the MMs.
|
||||
- output_permutation: Permutation that must be used to shuffle the output
|
||||
after executing the MMs.
|
||||
"""
|
||||
torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets,
|
||||
problem_sizes1, problem_sizes2,
|
||||
input_permutation, output_permutation,
|
||||
num_experts, n, k)
|
||||
|
||||
|
||||
def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
|
||||
b_tensors: torch.Tensor, a_scales: torch.Tensor,
|
||||
b_scales: torch.Tensor, expert_offsets: torch.Tensor,
|
||||
problem_sizes: torch.Tensor, a_strides: torch.Tensor,
|
||||
b_strides: torch.Tensor, c_strides: torch.Tensor):
|
||||
"""
|
||||
A single grouped matrix multiplication used in CUTLASS-based fused MoE.
|
||||
The function executes fp8-quantized OUT = AB matrix multiplication.
|
||||
|
||||
- expert_offsets: Indices that mark at which token index each expert begins
|
||||
its computation. The number of tokens computed with
|
||||
expert E is expert_offsets[E + 1] - expert_offsets[E]
|
||||
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
|
||||
MMs used in the fused MoE operation.
|
||||
- a/b/c_strides: The data strides passed to grouped matrix multiplication.
|
||||
"""
|
||||
torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, a_scales,
|
||||
b_scales, expert_offsets, problem_sizes,
|
||||
a_strides, b_strides, c_strides)
|
||||
|
||||
|
||||
# aqlm
|
||||
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
|
||||
codebooks: torch.Tensor, scales: torch.Tensor,
|
||||
|
@ -187,15 +187,28 @@ class ipex_ops:
|
||||
gen_: torch.Generator,
|
||||
logits_soft_cap: float,
|
||||
) -> None:
|
||||
ipex.llm.functional.varlen_attention(query.contiguous(),
|
||||
key.contiguous(),
|
||||
value.contiguous(), out,
|
||||
seqlen_q.int(), seqlen_k.int(),
|
||||
max_seqlen_q, max_seqlen_k,
|
||||
pdropout, softmax_scale,
|
||||
zero_tensors, is_causal,
|
||||
return_softmax, gen_,
|
||||
logits_soft_cap)
|
||||
if ipex.__version__.endswith("cpu"):
|
||||
if logits_soft_cap != 0.0:
|
||||
raise ValueError("IPEX CPU does not support logits_soft_cap")
|
||||
ipex.llm.functional.varlen_attention(query.contiguous(),
|
||||
key.contiguous(),
|
||||
value.contiguous(), out,
|
||||
seqlen_q.int(),
|
||||
seqlen_k.int(), max_seqlen_q,
|
||||
max_seqlen_k, pdropout,
|
||||
softmax_scale, zero_tensors,
|
||||
is_causal, return_softmax,
|
||||
gen_)
|
||||
else: # XPU build
|
||||
ipex.llm.functional.varlen_attention(query.contiguous(),
|
||||
key.contiguous(),
|
||||
value.contiguous(), out,
|
||||
seqlen_q.int(),
|
||||
seqlen_k.int(), max_seqlen_q,
|
||||
max_seqlen_k, pdropout,
|
||||
softmax_scale, zero_tensors,
|
||||
is_causal, return_softmax,
|
||||
gen_, logits_soft_cap)
|
||||
|
||||
@staticmethod
|
||||
def reshape_and_cache(
|
||||
|
303
vllm/attention/backends/cpu_mla.py
Normal file
303
vllm/attention/backends/cpu_mla.py
Normal file
@ -0,0 +1,303 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState
|
||||
from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata
|
||||
from vllm.utils import make_tensor_with_pad
|
||||
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
|
||||
|
||||
|
||||
class CPUMLABackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "CPU_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["CPUMLAMetadata"]:
|
||||
return CPUMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]:
|
||||
return CPUMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["MLACommonState"]:
|
||||
return MLACommonState
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["CPUMLAImpl"]:
|
||||
return CPUMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
ops.copy_blocks_mla(kv_caches, src_to_dists)
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [576]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CPUMLAMetadata(TorchSDPAMetadata):
|
||||
# New for MLA
|
||||
# Input positions for rotrary embeddings since for MLA the rotary
|
||||
# position embeddings are applied inside the attention backend
|
||||
input_positions: torch.Tensor = None
|
||||
|
||||
# required by MLACommonImpl
|
||||
is_profile_run: bool = False
|
||||
|
||||
|
||||
class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):
|
||||
|
||||
def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
|
||||
self.chunked_prefill = input_builder.chunked_prefill
|
||||
self.input_builder = input_builder
|
||||
assert not self.chunked_prefill, \
|
||||
"chunked prefill is currently not supported"
|
||||
|
||||
def prepare(self):
|
||||
self.input_data = self.input_builder.input_data
|
||||
|
||||
def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size):
|
||||
input_data = self.input_data
|
||||
prefill_seq_lens = seq_lens[0:input_data.num_prefills]
|
||||
prefill_query_lens = query_lens[0:input_data.num_prefills]
|
||||
slot_mapping = torch.tensor(input_data.slot_mapping,
|
||||
dtype=torch.long,
|
||||
device="cpu")
|
||||
|
||||
# metadata for prefill
|
||||
if input_data.num_prefills > 0:
|
||||
query_lens_tensor = torch.tensor(prefill_query_lens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
kv_lens_tensor = torch.tensor(prefill_seq_lens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
query_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
kv_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
torch.cumsum(query_lens_tensor,
|
||||
dim=0,
|
||||
dtype=torch.int32,
|
||||
out=query_start_loc[1:])
|
||||
torch.cumsum(kv_lens_tensor,
|
||||
dim=0,
|
||||
dtype=torch.int32,
|
||||
out=kv_start_loc[1:])
|
||||
max_query_len = max(prefill_query_lens)
|
||||
max_kv_len = max(prefill_seq_lens)
|
||||
|
||||
# for chunked-prefill
|
||||
if self.chunked_prefill:
|
||||
prefill_block_tables = make_tensor_with_pad(
|
||||
self.input_data.prefill_block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
else:
|
||||
prefill_block_tables = None
|
||||
|
||||
else:
|
||||
query_start_loc = None
|
||||
kv_start_loc = None
|
||||
max_query_len = None
|
||||
max_kv_len = None
|
||||
prefill_block_tables = None
|
||||
|
||||
# metadata for decode
|
||||
if input_data.num_decode_tokens != 0:
|
||||
seq_lens_tensor = torch.tensor(
|
||||
input_data.seq_lens[input_data.num_prefills:],
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.input_data.decode_block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
else:
|
||||
block_tables = torch.tensor([])
|
||||
seq_lens_tensor = torch.tensor(
|
||||
input_data.seq_lens[:input_data.num_prefills],
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
# For multi-modal models
|
||||
placeholder_index_maps = None
|
||||
if len(input_data.multi_modal_inputs_list) != 0:
|
||||
placeholder_index_maps = {
|
||||
modality: placeholder_map.index_map()
|
||||
for modality, placeholder_map in
|
||||
input_data.multi_modal_placeholder_maps.items()
|
||||
}
|
||||
|
||||
return CPUMLAMetadata(
|
||||
chunked_prefill=self.chunked_prefill,
|
||||
seq_lens=prefill_seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_kv_len=max_kv_len,
|
||||
query_start_loc=query_start_loc,
|
||||
kv_start_loc=kv_start_loc,
|
||||
max_decode_seq_len=input_data.max_decode_seq_len,
|
||||
num_prefills=input_data.num_prefills,
|
||||
num_prefill_tokens=input_data.num_prefill_tokens,
|
||||
num_decode_tokens=input_data.num_decode_tokens,
|
||||
block_tables=block_tables,
|
||||
prefill_block_tables=prefill_block_tables,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=False,
|
||||
input_positions=torch.tensor([self.input_data.input_positions]))
|
||||
|
||||
|
||||
class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
**mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"CPUMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"CPUMLAImpl")
|
||||
|
||||
# states is implemented.
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"CPUMLAImpl with FP8 KV cache not yet supported")
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: CPUMLAMetadata, # type: ignore[override]
|
||||
) -> torch.Tensor:
|
||||
|
||||
prefill_metadata = attn_metadata.prefill_metadata
|
||||
assert prefill_metadata is not None
|
||||
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
|
||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope\
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
|
||||
|
||||
# For MLA the v head dim is smaller than qk head dim so we pad out
|
||||
# v with 0s to match the qk head dim
|
||||
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
|
||||
value=0)
|
||||
|
||||
output = torch.empty_like(q)
|
||||
ipex_ops.varlen_attention(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v_padded,
|
||||
out=output,
|
||||
seqlen_q=prefill_metadata.query_start_loc,
|
||||
seqlen_k=prefill_metadata.query_start_loc,
|
||||
max_seqlen_q=prefill_metadata.max_query_len,
|
||||
max_seqlen_k=prefill_metadata.max_query_len,
|
||||
pdropout=0.0,
|
||||
softmax_scale=self.scale,
|
||||
zero_tensors=False,
|
||||
is_causal=True,
|
||||
return_softmax=False,
|
||||
gen_=None,
|
||||
logits_soft_cap=0.0,
|
||||
)
|
||||
|
||||
# remove padding
|
||||
output = output.view(-1, self.num_heads,
|
||||
q.shape[-1])[..., :v.shape[-1]]
|
||||
output = output.reshape(-1, self.num_heads * v.shape[-1])
|
||||
return self.o_proj(output)[0]
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: CPUMLAMetadata, # type: ignore[override]
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
|
||||
decode_meta = attn_metadata.decode_metadata
|
||||
assert decode_meta is not None
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank)
|
||||
|
||||
# Run MQA
|
||||
ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.seq_lens_tensor)
|
||||
return self._v_up_proj_and_o_proj(o)
|
@ -204,7 +204,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
@ -212,18 +211,27 @@ from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
|
||||
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.attention.ops.triton_flash_attention import triton_attention
|
||||
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
|
||||
else:
|
||||
merge_attn_states = None
|
||||
triton_attention = None
|
||||
|
||||
try:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
is_vllm_fa = True
|
||||
except ImportError:
|
||||
# For rocm use upstream flash attention
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
is_vllm_fa = False
|
||||
|
||||
from vllm.attention.ops.triton_flash_attention import triton_attention
|
||||
try:
|
||||
# For rocm use upstream flash attention
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
except ImportError:
|
||||
flash_attn_varlen_func = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
||||
|
@ -884,9 +884,8 @@ def _sdpa_attention(
|
||||
|
||||
for i, seq_len in enumerate(seq_lens):
|
||||
end = start + seq_len
|
||||
with torch.backends.cuda.sdp_kernel(enable_math=True,
|
||||
enable_flash=False,
|
||||
enable_mem_efficient=False):
|
||||
with torch.nn.attention.sdpa_kernel(
|
||||
torch.nn.attention.SDPBackend.MATH):
|
||||
sub_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query[:, start:end, :],
|
||||
key[:, start:end, :],
|
||||
@ -909,4 +908,5 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
|
||||
and (qtype == torch.half or qtype == torch.bfloat16)
|
||||
and (head_size == 64 or head_size == 128)
|
||||
and (block_size == 16 or block_size == 32)
|
||||
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
|
||||
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
|
||||
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
|
@ -54,6 +54,15 @@ def merge_attn_states_kernel(
|
||||
|
||||
p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
|
||||
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
|
||||
|
||||
# FA2 and FA3 have different behavior for when the sum-exp is 0, this namely
|
||||
# arises with 0 len seqlens. FA3 returns -inf here while FA2 returns inf.
|
||||
# If we see an inf assume FA2 and convert inf to -inf for consistency
|
||||
# and correctness. Inf generally doesn't make sense in this context outside
|
||||
# of undefined-behavior/FA2-case, so I think this a safe assumption.
|
||||
p_lse = float('-inf') if p_lse == float('inf') else p_lse
|
||||
s_lse = float('-inf') if s_lse == float('inf') else s_lse
|
||||
|
||||
max_lse = tl.maximum(p_lse, s_lse)
|
||||
p_lse = p_lse - max_lse
|
||||
s_lse = s_lse - max_lse
|
||||
|
@ -381,8 +381,8 @@ class VllmBackend:
|
||||
with open(filepath) as f:
|
||||
hash_content.append(f.read())
|
||||
import hashlib
|
||||
code_hash = hashlib.md5(
|
||||
"\n".join(hash_content).encode()).hexdigest()
|
||||
code_hash = hashlib.md5("\n".join(hash_content).encode(),
|
||||
usedforsecurity=False).hexdigest()
|
||||
factors.append(code_hash)
|
||||
|
||||
# 3. compiler hash
|
||||
@ -390,7 +390,8 @@ class VllmBackend:
|
||||
factors.append(compiler_hash)
|
||||
|
||||
# combine all factors to generate the cache dir
|
||||
hash_key = hashlib.md5(str(factors).encode()).hexdigest()[:10]
|
||||
hash_key = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()[:10]
|
||||
|
||||
cache_dir = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT,
|
||||
|
@ -139,7 +139,8 @@ class InductorAdaptor(CompilerInterface):
|
||||
from torch._inductor.codecache import torch_key
|
||||
torch_factors = torch_key()
|
||||
factors.append(torch_factors)
|
||||
hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()[:10]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
|
||||
@ -228,7 +229,20 @@ class InductorAdaptor(CompilerInterface):
|
||||
inductor_compiled_graph = output
|
||||
if inductor_compiled_graph is not None:
|
||||
nonlocal file_path
|
||||
file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
|
||||
compiled_fn = inductor_compiled_graph.current_callable
|
||||
file_path = compiled_fn.__code__.co_filename # noqa
|
||||
if not file_path.startswith(self.cache_dir):
|
||||
# hooked in the align_inputs_from_check_idxs function
|
||||
# in torch/_inductor/utils.py
|
||||
for cell in compiled_fn.__closure__:
|
||||
if not callable(cell.cell_contents):
|
||||
continue
|
||||
code = cell.cell_contents.__code__
|
||||
if code.co_filename.startswith(self.cache_dir):
|
||||
# this is the real file path
|
||||
# compiled from Inductor
|
||||
file_path = code.co_filename
|
||||
break
|
||||
hash_str = inductor_compiled_graph._fx_graph_cache_key
|
||||
return output
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user