Compare commits

..

14 Commits

Author SHA1 Message Date
44d638a896 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-03-25 10:26:20 -07:00
0a049c7d86 [CI/Build] Add tests for the V1 tpu_model_runner. (#14843)
Signed-off-by: Yarong Mu <ymu@google.com>
2025-03-25 12:27:16 -04:00
d0cfec7ab9 [bugfix] fix inductor cache on max_position_embeddings (#15436)
Signed-off-by: youkaichao <youkaichao@gmail.com>
2025-03-25 07:05:39 -07:00
a608160027 [Kernel] Fix conflicting macro names for gguf kernels (#15456)
Signed-off-by: SzymonOzog <szymon.ozog@gmail.com>
2025-03-25 13:50:49 +00:00
3f04a7fbf2 [Doc] Update V1 user guide for multi-modality (#15460)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-03-25 11:01:58 +00:00
5994430b84 [Misc] Remove redundant num_embeds (#15443)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-03-25 18:27:57 +08:00
a9e879b316 [Misc] Clean up MiniCPM-V/O code (#15337)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-03-25 10:22:52 +00:00
3e2f37a69a Dockerfile.ppc64le changes to move to UBI (#15402)
Signed-off-by: Md. Shafi Hussain <Md.Shafi.Hussain@ibm.com>
2025-03-25 10:15:14 +00:00
4f044b1d67 [Kernel][CPU] CPU MLA (#14744)
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
2025-03-25 09:34:59 +00:00
4157f563b4 [Hardware][TPU][Bugfix] Fix v1 mp profiler (#15409)
Signed-off-by: Siyuan Liu <lsiyuan@google.com>
2025-03-25 01:43:00 -07:00
051da7efe3 Fix CUDA kernel index data type in vllm/csrc/quantization/gptq_marlin/awq_marlin_repack.cu +10 (#15160)
Signed-off-by: Lu Fang <lufang@fb.com>
Co-authored-by: Richard Barnes <rbarnes@meta.com>
2025-03-25 15:36:45 +08:00
caacd1ddfb Merge branch 'main' into v1-block-table-opt 2025-01-23 22:59:55 -08:00
e68f63ef83 Simplify
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-01-15 02:31:16 -08:00
223e17424c working
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-01-15 01:24:45 -08:00
50 changed files with 2609 additions and 924 deletions

View File

@ -38,6 +38,8 @@ function cpu_tests() {
set -e
pip install -r vllm/requirements/test.txt
pip install -r vllm/requirements/cpu.txt
pytest -v -s tests/kernels/test_cache.py -m cpu_model
pytest -v -s tests/kernels/test_mla_decode_cpu.py -m cpu_model
pytest -v -s tests/models/decoder_only/language -m cpu_model
pytest -v -s tests/models/embedding/language -m cpu_model
pytest -v -s tests/models/encoder_decoder/language -m cpu_model

View File

@ -30,7 +30,9 @@ docker run --privileged --net host --shm-size=16G -it \
&& echo TEST_4 \
&& pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
&& echo TEST_5 \
&& python3 /workspace/vllm/examples/offline_inference/tpu.py" \
&& python3 /workspace/vllm/examples/offline_inference/tpu.py \
&& echo TEST_6 \
&& pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py" \
# TODO: This test fails because it uses RANDOM_SEED sampling

View File

@ -228,6 +228,7 @@ endif()
set(VLLM_EXT_SRC
"csrc/cache_kernels.cu"
"csrc/block_table.cu"
"csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu"
"csrc/pos_encoding_kernels.cu"

View File

@ -1,37 +1,267 @@
FROM mambaorg/micromamba
ARG MAMBA_DOCKERFILE_ACTIVATE=1
USER root
ARG BASE_UBI_IMAGE_TAG=9.5-1741850109
ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/"
###############################################################
# base stage with basic dependencies
###############################################################
RUN apt-get update -y && apt-get install -y git wget kmod curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 libssl-dev
FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS base-builder
# Some packages in requirements/cpu are installed here
# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba
# Currently these may not be available for venv or pip directly
RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 rust && micromamba clean --all --yes
ARG PYTHON_VERSION=3.12
ARG OPENBLAS_VERSION=0.3.29
# Set Environment Variables for venv, cargo & openblas
ENV VIRTUAL_ENV=/opt/vllm
ENV PATH=${VIRTUAL_ENV}/bin:/root/.cargo/bin:$PATH
ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig/
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64:/usr/local/lib:/usr/lib64:/usr/lib
ENV UV_LINK_MODE=copy
# install gcc-13, python, rust, openblas
# Note: A symlink for libatomic.so is created for gcc-13 (linker fails to find libatomic otherwise - reqd. for sentencepiece)
# Note: A dummy file 'control' is created in /tmp/ to artificially create dependencies between stages when building stages in parallel
# when `--jobs=<N>` is passed with podman build command
RUN microdnf install -y openssl-devel dnf \
&& dnf install -y https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os/Packages/centos-gpg-keys-9.0-24.el9.noarch.rpm \
https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os/Packages/centos-stream-repos-9.0-24.el9.noarch.rpm \
https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm \
&& dnf config-manager --add-repo https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os \
&& dnf config-manager --add-repo https://mirror.stream.centos.org/9-stream/AppStream/`arch`/os \
&& dnf config-manager --set-enabled crb \
&& dnf install -y \
git tar gcc-toolset-13 automake libtool numactl-devel lapack-devel \
pkgconfig xsimd zeromq-devel kmod findutils protobuf* \
libtiff-devel libjpeg-devel openjpeg2-devel zlib-devel \
freetype-devel lcms2-devel libwebp-devel tcl-devel tk-devel \
harfbuzz-devel fribidi-devel libraqm-devel libimagequant-devel libxcb-devel \
python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip \
&& dnf clean all \
&& ln -sf /usr/lib64/libatomic.so.1 /usr/lib64/libatomic.so \
&& python${PYTHON_VERSION} -m venv ${VIRTUAL_ENV} \
&& python -m pip install -U pip uv \
&& uv pip install wheel build "setuptools<70" setuptools_scm setuptools_rust meson-python cmake ninja cython scikit_build_core scikit_build \
&& curl -sL https://ftp2.osuosl.org/pub/ppc64el/openblas/latest/Openblas_${OPENBLAS_VERSION}_ppc64le.tar.gz | tar xvf - -C /usr/local \
&& curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \
&& cd /tmp && touch control
###############################################################
# Stage to build torch family
###############################################################
FROM base-builder AS torch-builder
ARG MAX_JOBS
ARG TORCH_VERSION=2.6.0
ARG _GLIBCXX_USE_CXX11_ABI=1
RUN --mount=type=cache,target=/root/.cache/uv \
source /opt/rh/gcc-toolset-13/enable && \
git clone --recursive https://github.com/pytorch/pytorch.git -b v${TORCH_VERSION} && \
cd pytorch && \
uv pip install -r requirements.txt && \
python setup.py develop && \
rm -f dist/torch*+git*whl && \
MAX_JOBS=${MAX_JOBS:-$(nproc)} \
PYTORCH_BUILD_VERSION=${TORCH_VERSION} PYTORCH_BUILD_NUMBER=1 uv build --wheel --out-dir /torchwheels/
ARG TORCHVISION_VERSION=0.21.0
ARG TORCHVISION_USE_NVJPEG=0
ARG TORCHVISION_USE_FFMPEG=0
RUN --mount=type=cache,target=/root/.cache/uv \
source /opt/rh/gcc-toolset-13/enable && \
git clone --recursive https://github.com/pytorch/vision.git -b v${TORCHVISION_VERSION} && \
cd vision && \
MAX_JOBS=${MAX_JOBS:-$(nproc)} \
BUILD_VERSION=${TORCHVISION_VERSION} \
uv build --wheel --out-dir /torchwheels/ --no-build-isolation
ARG TORCHAUDIO_VERSION=2.6.0
ARG BUILD_SOX=1
ARG BUILD_KALDI=1
ARG BUILD_RNNT=1
ARG USE_FFMPEG=0
ARG USE_ROCM=0
ARG USE_CUDA=0
ARG TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_FFMPEG=1
RUN --mount=type=cache,target=/root/.cache/uv \
source /opt/rh/gcc-toolset-13/enable && \
git clone --recursive https://github.com/pytorch/audio.git -b v${TORCHAUDIO_VERSION} && \
cd audio && \
MAX_JOBS=${MAX_JOBS:-$(nproc)} \
BUILD_VERSION=${TORCHAUDIO_VERSION} \
uv build --wheel --out-dir /torchwheels/ --no-build-isolation
###############################################################
# Stage to build pyarrow
###############################################################
FROM base-builder AS arrow-builder
ARG MAX_JOBS
ARG PYARROW_PARALLEL
ARG PYARROW_VERSION=19.0.1
RUN --mount=type=cache,target=/root/.cache/uv \
source /opt/rh/gcc-toolset-13/enable && \
git clone --recursive https://github.com/apache/arrow.git -b apache-arrow-${PYARROW_VERSION} && \
cd arrow/cpp && \
mkdir build && cd build && \
cmake -DCMAKE_BUILD_TYPE=release \
-DCMAKE_INSTALL_PREFIX=/usr/local \
-DARROW_PYTHON=ON \
-DARROW_BUILD_TESTS=OFF \
-DARROW_JEMALLOC=ON \
-DARROW_BUILD_STATIC="OFF" \
-DARROW_PARQUET=ON \
.. && \
make install -j ${MAX_JOBS:-$(nproc)} && \
cd ../../python/ && \
uv pip install -v -r requirements-wheel-build.txt && \
PYARROW_PARALLEL=${PYARROW_PARALLEL:-$(nproc)} \
python setup.py build_ext \
--build-type=release --bundle-arrow-cpp \
bdist_wheel --dist-dir /arrowwheels/
###############################################################
# Stage to build opencv
###############################################################
FROM base-builder AS cv-builder
ARG MAX_JOBS
ARG OPENCV_VERSION=84
ARG ENABLE_HEADLESS=1
RUN --mount=type=cache,target=/root/.cache/uv \
source /opt/rh/gcc-toolset-13/enable && \
git clone --recursive https://github.com/opencv/opencv-python.git -b ${OPENCV_VERSION} && \
cd opencv-python && \
sed -i 's/"setuptools==59.2.0",/"setuptools<70.0",/g' pyproject.toml && \
python -m build --wheel --installer=uv --outdir /opencvwheels/
###############################################################
# Stage to build vllm - this stage builds and installs
# vllm, tensorizer and vllm-tgis-adapter and builds uv cache
# for transitive dependencies - eg. grpcio
###############################################################
FROM base-builder AS vllmcache-builder
COPY --from=torch-builder /tmp/control /dev/null
COPY --from=arrow-builder /tmp/control /dev/null
COPY --from=cv-builder /tmp/control /dev/null
ARG VLLM_TARGET_DEVICE=cpu
# this step installs vllm and populates uv cache
# with all the transitive dependencies
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
--mount=type=bind,from=cv-builder,source=/opencvwheels/,target=/opencvwheels/,ro \
--mount=type=bind,src=.,dst=/src/,rw \
source /opt/rh/gcc-toolset-13/enable && \
uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl && \
sed -i -e 's/.*torch.*//g' /src/pyproject.toml /src/requirements/*.txt && \
uv pip install pandas pythran pybind11 && \
# sentencepiece.pc is in some pkgconfig inside uv cache
export PKG_CONFIG_PATH=$(find / -type d -name "pkgconfig" 2>/dev/null | tr '\n' ':') && \
uv pip install -r /src/requirements/common.txt -r /src/requirements/cpu.txt -r /src/requirements/build.txt --no-build-isolation && \
cd /src/ && \
uv build --wheel --out-dir /vllmwheel/ --no-build-isolation && \
uv pip install /vllmwheel/*.whl
###############################################################
# Stage to build numactl
###############################################################
FROM base-builder AS numa-builder
# Note: Building numactl with gcc-11. Compiling with gcc-13 in this builder stage will
# trigger recompilation with gcc-11 (and require libtool) in the final stage where we do not have gcc-13
ARG MAX_JOBS
ARG NUMACTL_VERSION=2.0.19
RUN git clone --recursive https://github.com/numactl/numactl.git -b v${NUMACTL_VERSION} \
&& cd numactl \
&& autoreconf -i && ./configure \
&& make -j ${MAX_JOBS:-$(nproc)}
###############################################################
# Stage to build lapack
###############################################################
FROM base-builder AS lapack-builder
ARG MAX_JOBS
ARG LAPACK_VERSION=3.12.1
RUN git clone --recursive https://github.com/Reference-LAPACK/lapack.git -b v${LAPACK_VERSION} \
&& cd lapack && source /opt/rh/gcc-toolset-13/enable \
&& cmake -B build -S . \
&& cmake --build build -j ${MAX_JOBS:-$(nproc)}
###############################################################
# FINAL VLLM IMAGE STAGE #
###############################################################
FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS vllm-openai
ARG PYTHON_VERSION=3.12
ARG OPENBLAS_VERSION=0.3.29
# Set Environment Variables for venv & openblas
ENV VIRTUAL_ENV=/opt/vllm
ENV PATH=${VIRTUAL_ENV}/bin:$PATH
ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig/
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64:/usr/local/lib:/usr/lib64:/usr/lib
ENV UV_LINK_MODE=copy
# create artificial dependencies between stages for independent stages to build in parallel
COPY --from=torch-builder /tmp/control /dev/null
COPY --from=arrow-builder /tmp/control /dev/null
COPY --from=cv-builder /tmp/control /dev/null
COPY --from=vllmcache-builder /tmp/control /dev/null
COPY --from=numa-builder /tmp/control /dev/null
COPY --from=lapack-builder /tmp/control /dev/null
# install gcc-11, python, openblas, numactl, lapack
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=numa-builder,source=/numactl/,target=/numactl/,rw \
--mount=type=bind,from=lapack-builder,source=/lapack/,target=/lapack/,rw \
rpm -ivh https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm && \
microdnf install --nodocs -y \
tar findutils openssl \
pkgconfig xsimd g++ gcc-fortran libsndfile \
libtiff libjpeg openjpeg2 zlib zeromq \
freetype lcms2 libwebp tcl tk utf8proc \
harfbuzz fribidi libraqm libimagequant libxcb \
python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip \
&& microdnf clean all \
&& python${PYTHON_VERSION} -m venv ${VIRTUAL_ENV} \
&& python -m pip install -U pip uv --no-cache \
&& curl -sL https://ftp2.osuosl.org/pub/ppc64el/openblas/latest/Openblas_${OPENBLAS_VERSION}_ppc64le.tar.gz | tar xvf - -C /usr/local \
&& make -C /numactl install \
&& uv pip install cmake \
&& cmake --install /lapack/build \
&& uv pip uninstall cmake
# consume previously built wheels (including vllm)
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
--mount=type=bind,from=cv-builder,source=/opencvwheels/,target=/opencvwheels/,ro \
--mount=type=bind,from=vllmcache-builder,source=/vllmwheel/,target=/vllmwheel/,ro \
HOME=/root uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /vllmwheel/*.whl
COPY ./ /workspace/vllm
WORKDIR /workspace/vllm
ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
RUN --mount=type=cache,target=/root/.cache/pip \
RUSTFLAGS='-L /opt/conda/lib' pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
-r requirements/cpu.txt \
xformers uvloop==0.20.0
RUN --mount=type=bind,source=.git,target=.git \
VLLM_TARGET_DEVICE=cpu python3 setup.py install
# install development dependencies (for testing)
RUN python3 -m pip install -e tests/vllm_test_utils
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install -e tests/vllm_test_utils
WORKDIR /workspace/
RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
ENTRYPOINT ["/opt/conda/bin/python3", "-m", "vllm.entrypoints.openai.api_server"]
ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"]

View File

@ -190,6 +190,7 @@ set(VLLM_EXT_SRC
"csrc/cpu/cache.cpp"
"csrc/cpu/utils.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/mla_decode.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/cpu/torch_bindings.cpp")

92
csrc/block_table.cu Normal file
View File

@ -0,0 +1,92 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
namespace vllm {
__global__ void append_kernel(const int* __restrict__ row_indices,
const int* __restrict__ cu_num_appends,
const int* __restrict__ block_ids,
int* __restrict__ block_table,
int max_num_blocks_per_row) {
int bid = blockIdx.x;
int tgt_row = row_indices[2 * bid];
int tgt_offset = row_indices[2 * bid + 1];
int start = cu_num_appends[bid];
int end = cu_num_appends[bid + 1];
int length = end - start;
int tid = threadIdx.x;
int64_t offset = tgt_row * max_num_blocks_per_row + tgt_offset;
for (int i = tid; i < length; i += blockDim.x) {
block_table[offset + i] = block_ids[start + i];
}
}
__global__ void move_kernel(const int* __restrict__ src_dst_n,
int* __restrict__ block_table,
int max_num_blocks_per_row) {
int bid = blockIdx.x;
int src_row = src_dst_n[3 * bid];
int tgt_row = src_dst_n[3 * bid + 1];
int num_blocks = src_dst_n[3 * bid + 2];
int tid = threadIdx.x;
for (int i = tid; i < num_blocks; i += blockDim.x) {
block_table[tgt_row * max_num_blocks_per_row + i] =
block_table[src_row * max_num_blocks_per_row + i];
}
}
} // namespace vllm
void block_table_appends(
torch::Tensor& append_row_indices,
torch::Tensor& append_row_indices_cpu,
torch::Tensor& append_cumsums,
torch::Tensor& append_cumsums_cpu,
torch::Tensor& append_block_ids,
torch::Tensor& append_block_ids_cpu,
torch::Tensor& block_table,
int64_t num_appends,
int64_t total_num_append_blocks) {
int* append_row_indices_ptr = append_row_indices.data_ptr<int>();
const int* append_row_indices_cpu_ptr = append_row_indices_cpu.data_ptr<int>();
int* append_cumsums_ptr = append_cumsums.data_ptr<int>();
const int* append_cumsums_cpu_ptr = append_cumsums_cpu.data_ptr<int>();
int* append_block_ids_ptr = append_block_ids.data_ptr<int>();
const int* append_block_ids_cpu_ptr = append_block_ids_cpu.data_ptr<int>();
int* block_table_ptr = block_table.data_ptr<int>();
const at::cuda::OptionalCUDAGuard device_guard(device_of(block_table));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
cudaMemcpyAsync(append_row_indices_ptr, append_row_indices_cpu_ptr,
num_appends * 2 * sizeof(int), cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync(append_cumsums_ptr, append_cumsums_cpu_ptr,
(num_appends + 1) * sizeof(int), cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync(append_block_ids_ptr, append_block_ids_cpu_ptr,
total_num_append_blocks * sizeof(int), cudaMemcpyHostToDevice, stream);
int64_t max_num_blocks_per_row = block_table.size(1);
vllm::append_kernel<<<num_appends, 1024, 0, stream>>>(
append_row_indices_ptr, append_cumsums_ptr, append_block_ids_ptr,
block_table_ptr, max_num_blocks_per_row);
}
void block_table_moves(
torch::Tensor& src_dst_n,
torch::Tensor& src_dst_n_cpu,
torch::Tensor& block_table,
int64_t num_moves) {
int* src_dst_n_ptr = src_dst_n.data_ptr<int>();
const int* src_dst_n_cpu_ptr = src_dst_n_cpu.data_ptr<int>();
int* block_table_ptr = block_table.data_ptr<int>();
const at::cuda::OptionalCUDAGuard device_guard(device_of(block_table));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
cudaMemcpyAsync(src_dst_n_ptr, src_dst_n_cpu_ptr,
num_moves * 3 * sizeof(int), cudaMemcpyHostToDevice, stream);
int64_t max_num_blocks_per_row = block_table.size(1);
vllm::move_kernel<<<num_moves, 1024, 0, stream>>>(
src_dst_n_ptr, block_table_ptr, max_num_blocks_per_row);
}

View File

@ -88,6 +88,48 @@ void reshape_and_cache_cpu_impl(
}
}; // namespace
template <typename scalar_t>
void concat_and_cache_mla_cpu_impl(
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
scalar_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int num_tokens, //
const int block_stride, //
const int entry_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
const int pe_dim, //
const int block_size //
) {
#pragma omp parallel for
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
continue;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
auto copy = [&](const scalar_t* __restrict__ src,
scalar_t* __restrict__ dst, int src_stride, int dst_stride,
int size, int offset) {
for (int i = 0; i < size; i++) {
const int64_t src_idx = token_idx * src_stride + i;
const int64_t dst_idx =
block_idx * block_stride + block_offset * entry_stride + i + offset;
dst[dst_idx] = src[src_idx];
}
};
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}
}
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
@ -134,6 +176,38 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
});
}
void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& scale) {
int num_tokens = slot_mapping.size(0);
int kv_lora_rank = kv_c.size(1);
int pe_dim = k_pe.size(1);
int block_size = kv_cache.size(1);
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
TORCH_CHECK(kv_cache_dtype != "fp8");
int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);
int entry_stride = kv_cache.stride(1);
VLLM_DISPATCH_FLOATING_TYPES(
kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl)
concat_and_cache_mla_cpu_impl<scalar_t>(
kv_c.data_ptr<scalar_t>(), k_pe.data_ptr<scalar_t>(),
kv_cache.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(),
num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride,
kv_lora_rank, pe_dim, block_size);
CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl)
});
}
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")

View File

@ -130,6 +130,8 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
__m512i reg;
explicit BF16Vec32() : reg(_mm512_setzero_si512()) {}
explicit BF16Vec32(const void* ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
explicit BF16Vec32(__m512i data) : reg(data) {}

393
csrc/cpu/mla_decode.cpp Normal file
View File

@ -0,0 +1,393 @@
#include "cpu_types.hpp"
#include <float.h>
namespace {
template <typename scalar_t>
struct KernelVecType {
using qk_load_vec_type = void;
using qk_vec_type = void;
using v_load_vec_type = void;
};
template <>
struct KernelVecType<float> {
using qk_load_vec_type = vec_op::FP32Vec16;
using qk_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::FP32Vec16;
};
template <>
struct KernelVecType<c10::Half> {
#if defined(__powerpc64__) || defined(__s390x__)
// Power and s390x architecture-specific vector types
using qk_load_vec_type = vec_op::FP32Vec16;
using qk_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::FP32Vec16;
#else
// Fallback for other architectures, including x86
using qk_load_vec_type = vec_op::FP16Vec16;
using qk_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::FP16Vec16;
#endif
};
#ifdef __AVX512BF16__
template <>
struct KernelVecType<c10::BFloat16> {
using qk_load_vec_type = vec_op::BF16Vec32;
using qk_vec_type = vec_op::BF16Vec32;
using v_load_vec_type = vec_op::BF16Vec16;
};
#elif defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
// pass
#else
template <>
struct KernelVecType<c10::BFloat16> {
using qk_load_vec_type = vec_op::BF16Vec16;
using qk_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::BF16Vec16;
};
#endif
template <int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE, int HEAD_UNROLL,
typename qk_vec_type>
void mla_decode_block_head(
const qk_vec_type* __restrict__ q_vecs, // [HEAD_UNROLL, head_dim]
const qk_vec_type* __restrict__ k_vecs, // [block_size, head_dim]
const vec_op::FP32Vec16* __restrict v_vecs_f32, // [block_size, v_head_dim]
float* __restrict__ acc_out, // [HEAD_UNROLL, v_head_dim]
float* __restrict__ acc_lse, // [HEAD_UNROLL]
const float scale, const int num_tokens) {
using f32_vec_type = vec_op::FP32Vec16;
constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
constexpr int V_NUM_ELEM = f32_vec_type::VEC_ELEM_NUM;
float logits[BLOCK_SIZE][HEAD_UNROLL] = {}; // initialize to zeros
float max_val[HEAD_UNROLL];
std::fill(max_val, max_val + HEAD_UNROLL, -FLT_MAX);
f32_vec_type acc_vec[BLOCK_SIZE][HEAD_UNROLL];
for (int i = 0; i < HEAD_DIM; i += QK_NUM_ELEM) {
// load to registers
qk_vec_type q_vec[HEAD_UNROLL];
#pragma unroll
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
q_vec[unroll] =
qk_vec_type{q_vecs[(i + unroll * HEAD_DIM) / QK_NUM_ELEM]};
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
qk_vec_type k_vec(k_vecs[(block_offset * HEAD_DIM + i) / QK_NUM_ELEM]);
#pragma unroll
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
vec_op::fma(acc_vec[block_offset][unroll], q_vec[unroll], k_vec);
}
}
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
#pragma unroll
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
const float acc = acc_vec[block_offset][unroll].reduce_sum() * scale;
logits[block_offset][unroll] = acc;
max_val[unroll] = std::max(max_val[unroll], acc);
}
}
float sum_exp[HEAD_UNROLL] = {};
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
#pragma unroll
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
const float val =
std::exp(logits[block_offset][unroll] - max_val[unroll]);
logits[block_offset][unroll] = val;
sum_exp[unroll] += val;
}
}
f32_vec_type this_out[V_HEAD_DIM / V_NUM_ELEM][HEAD_UNROLL];
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
// load to registers
f32_vec_type scale_[HEAD_UNROLL];
#pragma unroll
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
scale_[unroll] =
f32_vec_type{logits[block_offset][unroll] / sum_exp[unroll]};
for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) {
f32_vec_type v_vec(
v_vecs_f32[(block_offset * HEAD_DIM + i) / V_NUM_ELEM]);
#pragma unroll
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
vec_op::fma(this_out[i / V_NUM_ELEM][unroll], v_vec, scale_[unroll]);
}
}
// merge attention state
// section 2.2 in https://arxiv.org/pdf/2501.01005
f32_vec_type prev_scale[HEAD_UNROLL];
f32_vec_type curr_scale[HEAD_UNROLL];
#pragma unroll
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
const float prev_lse = acc_lse[unroll];
const float curr_lse = std::log(sum_exp[unroll]) +
max_val[unroll]; // add back max_val to get true lse
// softmax trick
const float max_lse = std::max(prev_lse, curr_lse);
const float prev_sum_exp = std::exp(prev_lse - max_lse);
const float curr_sum_exp = std::exp(curr_lse - max_lse);
const float new_sum_exp = prev_sum_exp + curr_sum_exp;
acc_lse[unroll] = std::log(new_sum_exp) + max_lse;
prev_scale[unroll] = f32_vec_type{prev_sum_exp / new_sum_exp};
curr_scale[unroll] = f32_vec_type{curr_sum_exp / new_sum_exp};
}
for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) {
#pragma unroll
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
f32_vec_type o_vec(acc_out + i + V_HEAD_DIM * unroll);
o_vec = o_vec * prev_scale[unroll] +
this_out[i / V_NUM_ELEM][unroll] * curr_scale[unroll];
o_vec.save(acc_out + i + V_HEAD_DIM * unroll);
}
}
q_vecs += HEAD_DIM / QK_NUM_ELEM * HEAD_UNROLL;
acc_out += V_HEAD_DIM * HEAD_UNROLL;
}
template <typename scalar_t, int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE,
typename qk_vec_type>
void mla_decode_block(
const qk_vec_type* __restrict__ q_vecs, // [num_heads, head_dim]
const scalar_t* __restrict__ kv_cache, // [block_size, head_dim]
float* __restrict__ acc_out, // [num_heads, v_head_dim]
float* __restrict__ acc_lse, // [num_heads]
const int num_heads, const float scale, const int num_tokens) {
using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
static_assert(
std::is_same<qk_vec_type,
typename KernelVecType<scalar_t>::qk_vec_type>::value);
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
using f32_vec_type = vec_op::FP32Vec16;
static_assert(qk_load_vec_type::VEC_ELEM_NUM == qk_vec_type::VEC_ELEM_NUM);
static_assert(v_load_vec_type::VEC_ELEM_NUM == f32_vec_type::VEC_ELEM_NUM);
constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
constexpr int V_NUM_ELEM = v_load_vec_type::VEC_ELEM_NUM;
const qk_vec_type* k_vecs;
const f32_vec_type* v_vecs_f32;
float* kv_cache_f32 = nullptr;
if constexpr (!std::is_same<scalar_t, float>::value) {
// convert KV cache block to FP32 to reuse it across query heads and
// attn @ V computation, since FP16/BF16->FP32 is expensive.
// TODO: move malloc outside of this fn to reuse across iterations.
const int nbytes = BLOCK_SIZE * HEAD_DIM * sizeof(float);
kv_cache_f32 = static_cast<float*>(std::aligned_alloc(64, nbytes));
for (int block_offset = 0; block_offset < num_tokens; ++block_offset)
for (int i = 0; i < HEAD_DIM; i += V_NUM_ELEM) {
v_load_vec_type kv_load_vec(kv_cache + block_offset * HEAD_DIM + i);
f32_vec_type kv_vec_f32(kv_load_vec);
kv_vec_f32.save(kv_cache_f32 + block_offset * HEAD_DIM + i);
}
if constexpr (std::is_same<qk_load_vec_type, qk_vec_type>::value) {
// for AVX512_BF16, Q @ K.T uses BF16 for K (no conversion)
// NOTE: in this case, we only need to convert the V section to FP32.
// But for simplicity, we will convert the whole KV block to FP32.
k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
} else {
k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache_f32);
}
// attn @ V always use FP32 for V, since attn is FP32.
v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache_f32);
} else {
// KV cache is FP32. don't need to do anything.
k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache);
}
// compute 2 heads at the same time to improve ILP and
// take advantage of register cache for K and V.
constexpr int HEAD_UNROLL = 2;
for (int iter = 0; iter < num_heads / HEAD_UNROLL; ++iter) {
mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, HEAD_UNROLL>(
q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);
q_vecs += HEAD_UNROLL * HEAD_DIM / QK_NUM_ELEM;
acc_out += HEAD_UNROLL * V_HEAD_DIM;
acc_lse += HEAD_UNROLL;
}
// take care of the remaining heads
for (int iter = 0; iter < num_heads % HEAD_UNROLL; ++iter) {
mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, 1>(
q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);
q_vecs += HEAD_DIM / QK_NUM_ELEM;
acc_out += V_HEAD_DIM;
acc_lse += 1;
}
if (kv_cache_f32 != nullptr) {
std::free(kv_cache_f32);
}
}
} // namespace
template <typename scalar_t, int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE>
void mla_decode_kvcache_cpu_impl(
scalar_t* __restrict__ out, // [num_seqs, num_heads, v_head_dim]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_dim]
const scalar_t* __restrict__ kv_cache, // [num_blocks, block_size,
// head_dim]
const int num_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int o_stride, const int q_stride,
const int kv_stride, const int num_seqs) {
using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
using qk_vec_type = typename KernelVecType<scalar_t>::qk_vec_type;
constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
// shared across threads
const int max_threads = omp_get_max_threads();
const int acc_out_nbytes =
max_threads * num_heads * V_HEAD_DIM * sizeof(float);
float* acc_out = static_cast<float*>(std::aligned_alloc(64, acc_out_nbytes));
std::vector<float> acc_lse(max_threads * num_heads);
// allocate memory to pre-convert query to FP32 later
float* q_f32;
constexpr bool PRE_CONVERT_QUERY =
!std::is_same<scalar_t, float>::value &&
std::is_same<qk_vec_type, vec_op::FP32Vec16>::value;
if constexpr (PRE_CONVERT_QUERY) {
const int q_f32_nbytes = num_heads * HEAD_DIM * sizeof(float);
q_f32 = static_cast<float*>(std::aligned_alloc(64, q_f32_nbytes));
}
#pragma omp parallel
{
const int num_threads = omp_get_num_threads();
const int thread_id = omp_get_thread_num();
float* __restrict__ acc_out_thread =
acc_out + thread_id * num_heads * V_HEAD_DIM;
float* __restrict__ acc_lse_thread = acc_lse.data() + thread_id * num_heads;
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
// reset accumulator
std::fill(acc_out_thread, acc_out_thread + num_heads * V_HEAD_DIM, 0.0f);
std::fill(acc_lse_thread, acc_lse_thread + num_heads, -FLT_MAX);
const int seq_len = seq_lens[seq_idx];
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int last_block_size = seq_len - (block_num - 1) * BLOCK_SIZE;
const qk_vec_type* q_vecs;
if constexpr (PRE_CONVERT_QUERY) {
// pre-convert query to FP32 since FP16/BF16->FP32 is slow.
#pragma omp for
for (int i = 0; i < num_heads * HEAD_DIM; i += QK_NUM_ELEM) {
qk_load_vec_type q_load_vec(q + seq_idx * q_stride + i);
qk_vec_type q_vec(q_load_vec);
q_vec.save(q_f32 + i);
}
q_vecs = reinterpret_cast<const qk_vec_type*>(q_f32);
} else {
q_vecs = reinterpret_cast<const qk_vec_type*>(q + seq_idx * q_stride);
}
#pragma omp for
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int physical_block_idx =
block_tables[seq_idx * max_num_blocks_per_seq + block_idx];
const int num_tokens =
block_idx < block_num - 1 ? BLOCK_SIZE : last_block_size;
mla_decode_block<scalar_t, HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE>(
q_vecs, kv_cache + physical_block_idx * kv_stride, acc_out_thread,
acc_lse_thread, num_heads, scale, num_tokens);
}
// merge attention states across threads
// section 2.2 in https://arxiv.org/pdf/2501.01005
// each thread is responsible for 1 head
#pragma omp for
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
float* acc_lse_head = acc_lse.data() + head_idx;
float* acc_out_head = acc_out + head_idx * V_HEAD_DIM;
float max_val = -FLT_MAX;
for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
max_val = std::max(max_val, acc_lse_head[thread_id_ * num_heads]);
}
float sum_exp = 0.0f;
for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
float val = std::exp(acc_lse_head[thread_id_ * num_heads] - max_val);
acc_lse_head[thread_id_ * num_heads] = val;
sum_exp += val;
}
float inv_sum = 1.0f / sum_exp;
float out_head[V_HEAD_DIM] = {};
for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
float scale_ = acc_lse_head[thread_id_ * num_heads] * inv_sum;
for (int i = 0; i < V_HEAD_DIM; ++i) {
out_head[i] +=
acc_out_head[thread_id_ * num_heads * V_HEAD_DIM + i] * scale_;
}
}
for (int i = 0; i < V_HEAD_DIM; ++i) {
vec_op::storeFP32(out_head[i], out + seq_idx * o_stride +
head_idx * V_HEAD_DIM + i);
}
}
}
}
if (PRE_CONVERT_QUERY) {
std::free(q_f32);
}
std::free(acc_out);
}
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& kv_cache, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens) {
const int num_seqs = query.size(0);
const int num_heads = query.size(1);
const int head_dim = query.size(2);
const int block_size = kv_cache.size(1);
const int v_head_dim = out.size(2);
const int max_num_blocks_per_seq = block_tables.size(1);
const int o_stride = out.stride(0);
const int q_stride = query.stride(0);
const int kv_stride = kv_cache.stride(0);
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(), "mla_decode_kvcache_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(mla_decode_kvcache_cpu_impl)
if (head_dim == 576 && v_head_dim == 512 && block_size == 16)
mla_decode_kvcache_cpu_impl<scalar_t, 576, 512, 16>(
out.data_ptr<scalar_t>(), query.data_ptr<scalar_t>(),
kv_cache.data_ptr<scalar_t>(), num_heads, scale,
block_tables.data_ptr<int>(), seq_lens.data_ptr<int>(),
max_num_blocks_per_seq, o_stride, q_stride, kv_stride, num_seqs);
else
TORCH_CHECK(false, "Unsupported block size: ", block_size);
CPU_KERNEL_GUARD_OUT(mla_decode_kvcache_cpu_impl)
});
}

View File

@ -18,6 +18,10 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
const std::optional<torch::Tensor>& azp,
const std::optional<torch::Tensor>& bias);
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& kv_cache, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
@ -150,6 +154,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
cache_ops.def(
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()");
cache_ops.impl("concat_and_cache_mla", torch::kCPU, &concat_and_cache_mla);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
@ -157,4 +169,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cpu), cpu_ops) {
cpu_ops.def(
"mla_decode_kvcache("
" Tensor! out, Tensor query, Tensor kv_cache,"
" float scale, Tensor block_tables, Tensor seq_lens) -> ()");
cpu_ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

View File

@ -119,6 +119,18 @@ void advance_step_flashinfer(
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
void block_table_appends(torch::Tensor& append_row_indices,
torch::Tensor& append_row_indices_cpu,
torch::Tensor& append_cumsums,
torch::Tensor& append_cumsums_cpu,
torch::Tensor& append_block_ids,
torch::Tensor& append_block_ids_cpu,
torch::Tensor& block_table, int64_t num_appends,
int64_t total_num_append_blocks);
void block_table_moves(torch::Tensor& src_dst_n, torch::Tensor& src_dst_n_cpu,
torch::Tensor& block_table, int64_t num_moves);
#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,

View File

@ -375,25 +375,25 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input
int64_t ggml_moe_get_block_size(int64_t type) {
switch (type) {
case 2:
return MMQ_X_Q4_0;
return MOE_X_Q4_0;
case 3:
return MMQ_X_Q4_1;
return MOE_X_Q4_1;
case 6:
return MMQ_X_Q5_0;
return MOE_X_Q5_0;
case 7:
return MMQ_X_Q5_1;
return MOE_X_Q5_1;
case 8:
return MMQ_X_Q8_0;
return MOE_X_Q8_0;
case 10:
return MMQ_X_Q2_K;
return MOE_X_Q2_K;
case 11:
return MMQ_X_Q3_K;
return MOE_X_Q3_K;
case 12:
return MMQ_X_Q4_K;
return MOE_X_Q4_K;
case 13:
return MMQ_X_Q5_K;
return MOE_X_Q5_K;
case 14:
return MMQ_X_Q6_K;
return MOE_X_Q6_K;
}
return 0;
}

View File

@ -129,12 +129,12 @@ static __device__ __forceinline__ void moe_q(
}
#if defined(USE_ROCM)
#define MMQ_X_Q4_0 64
#define MMQ_Y_Q4_0 128
#define MOE_X_Q4_0 64
#define MOE_Y_Q4_0 128
#define NWARPS_Q4_0 8
#else
#define MMQ_X_Q4_0 4
#define MMQ_Y_Q4_0 32
#define MOE_X_Q4_0 4
#define MOE_Y_Q4_0 32
#define NWARPS_Q4_0 4
#endif
@ -149,8 +149,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_0, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
const int mmq_x = MMQ_X_Q4_0;
const int mmq_y = MMQ_Y_Q4_0;
const int mmq_x = MOE_X_Q4_0;
const int mmq_y = MOE_Y_Q4_0;
const int nwarps = NWARPS_Q4_0;
moe_q<scalar_t, QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps,
@ -167,8 +167,8 @@ static void ggml_moe_q4_0_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
int mmq_x = MMQ_X_Q4_0;
int mmq_y = MMQ_Y_Q4_0;
int mmq_x = MOE_X_Q4_0;
int mmq_y = MOE_Y_Q4_0;
int nwarps = NWARPS_Q4_0;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -190,12 +190,12 @@ static void ggml_moe_q4_0_q8_1_cuda(
}
#if defined(USE_ROCM)
#define MMQ_X_Q4_1 64
#define MMQ_Y_Q4_1 128
#define MOE_X_Q4_1 64
#define MOE_Y_Q4_1 128
#define NWARPS_Q4_1 8
#else
#define MMQ_X_Q4_1 4
#define MMQ_Y_Q4_1 32
#define MOE_X_Q4_1 4
#define MOE_Y_Q4_1 32
#define NWARPS_Q4_1 4
#endif
@ -210,8 +210,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_1, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
const int mmq_x = MMQ_X_Q4_1;
const int mmq_y = MMQ_Y_Q4_1;
const int mmq_x = MOE_X_Q4_1;
const int mmq_y = MOE_Y_Q4_1;
const int nwarps = NWARPS_Q4_1;
moe_q<scalar_t, QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps,
@ -228,8 +228,8 @@ static void ggml_moe_q4_1_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
int mmq_x = MMQ_X_Q4_1;
int mmq_y = MMQ_Y_Q4_1;
int mmq_x = MOE_X_Q4_1;
int mmq_y = MOE_Y_Q4_1;
int nwarps = NWARPS_Q4_1;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -251,12 +251,12 @@ static void ggml_moe_q4_1_q8_1_cuda(
}
#if defined(USE_ROCM)
#define MMQ_X_Q5_0 64
#define MMQ_Y_Q5_0 128
#define MOE_X_Q5_0 64
#define MOE_Y_Q5_0 128
#define NWARPS_Q5_0 8
#else
#define MMQ_X_Q5_0 4
#define MMQ_Y_Q5_0 32
#define MOE_X_Q5_0 4
#define MOE_Y_Q5_0 32
#define NWARPS_Q5_0 4
#endif
@ -271,8 +271,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_0, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
const int mmq_x = MMQ_X_Q5_0;
const int mmq_y = MMQ_Y_Q5_0;
const int mmq_x = MOE_X_Q5_0;
const int mmq_y = MOE_Y_Q5_0;
const int nwarps = NWARPS_Q5_0;
moe_q<scalar_t, QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps,
@ -289,8 +289,8 @@ static void ggml_moe_q5_0_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q5_0;
const int mmq_y = MMQ_Y_Q5_0;
const int mmq_x = MOE_X_Q5_0;
const int mmq_y = MOE_Y_Q5_0;
const int nwarps = NWARPS_Q5_0;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -312,12 +312,12 @@ static void ggml_moe_q5_0_q8_1_cuda(
}
#if defined(USE_ROCM)
#define MMQ_X_Q5_1 64
#define MMQ_Y_Q5_1 128
#define MOE_X_Q5_1 64
#define MOE_Y_Q5_1 128
#define NWARPS_Q5_1 8
#else
#define MMQ_X_Q5_1 4
#define MMQ_Y_Q5_1 32
#define MOE_X_Q5_1 4
#define MOE_Y_Q5_1 32
#define NWARPS_Q5_1 4
#endif
@ -332,8 +332,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_1, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
const int mmq_x = MMQ_X_Q5_1;
const int mmq_y = MMQ_Y_Q5_1;
const int mmq_x = MOE_X_Q5_1;
const int mmq_y = MOE_Y_Q5_1;
const int nwarps = NWARPS_Q5_1;
moe_q<scalar_t, QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps,
@ -350,8 +350,8 @@ static void ggml_moe_q5_1_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q5_1;
const int mmq_y = MMQ_Y_Q5_1;
const int mmq_x = MOE_X_Q5_1;
const int mmq_y = MOE_Y_Q5_1;
const int nwarps = NWARPS_Q5_1;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -373,12 +373,12 @@ static void ggml_moe_q5_1_q8_1_cuda(
}
#if defined(USE_ROCM)
#define MMQ_X_Q8_0 64
#define MMQ_Y_Q8_0 128
#define MOE_X_Q8_0 64
#define MOE_Y_Q8_0 128
#define NWARPS_Q8_0 8
#else
#define MMQ_X_Q8_0 4
#define MMQ_Y_Q8_0 32
#define MOE_X_Q8_0 4
#define MOE_Y_Q8_0 32
#define NWARPS_Q8_0 4
#endif
@ -393,8 +393,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q8_0, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
const int mmq_x = MMQ_X_Q8_0;
const int mmq_y = MMQ_Y_Q8_0;
const int mmq_x = MOE_X_Q8_0;
const int mmq_y = MOE_Y_Q8_0;
const int nwarps = NWARPS_Q8_0;
moe_q<scalar_t, QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps,
@ -411,8 +411,8 @@ static void ggml_moe_q8_0_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q8_0;
const int mmq_y = MMQ_Y_Q8_0;
const int mmq_x = MOE_X_Q8_0;
const int mmq_y = MOE_Y_Q8_0;
const int nwarps = NWARPS_Q8_0;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -434,12 +434,12 @@ static void ggml_moe_q8_0_q8_1_cuda(
}
#if defined(USE_ROCM)
#define MMQ_X_Q2_K 64
#define MMQ_Y_Q2_K 128
#define MOE_X_Q2_K 64
#define MOE_Y_Q2_K 128
#define NWARPS_Q2_K 8
#else
#define MMQ_X_Q2_K 4
#define MMQ_Y_Q2_K 32
#define MOE_X_Q2_K 4
#define MOE_Y_Q2_K 32
#define NWARPS_Q2_K 4
#endif
@ -454,8 +454,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q2_K, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
const int mmq_x = MMQ_X_Q2_K;
const int mmq_y = MMQ_Y_Q2_K;
const int mmq_x = MOE_X_Q2_K;
const int mmq_y = MOE_Y_Q2_K;
const int nwarps = NWARPS_Q2_K;
moe_q<scalar_t, QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps,
@ -472,8 +472,8 @@ static void ggml_moe_q2_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q2_K;
const int mmq_y = MMQ_Y_Q2_K;
const int mmq_x = MOE_X_Q2_K;
const int mmq_y = MOE_Y_Q2_K;
const int nwarps = NWARPS_Q2_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -495,12 +495,12 @@ static void ggml_moe_q2_K_q8_1_cuda(
}
#if defined(USE_ROCM)
#define MMQ_X_Q3_K 64
#define MMQ_Y_Q3_K 128
#define MOE_X_Q3_K 64
#define MOE_Y_Q3_K 128
#define NWARPS_Q3_K 8
#else
#define MMQ_X_Q3_K 4
#define MMQ_Y_Q3_K 32
#define MOE_X_Q3_K 4
#define MOE_Y_Q3_K 32
#define NWARPS_Q3_K 4
#endif
@ -516,8 +516,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q3_K, 2)
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
const int mmq_x = MMQ_X_Q3_K;
const int mmq_y = MMQ_Y_Q3_K;
const int mmq_x = MOE_X_Q3_K;
const int mmq_y = MOE_Y_Q3_K;
const int nwarps = NWARPS_Q3_K;
moe_q<scalar_t, QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps,
@ -533,8 +533,8 @@ static void ggml_moe_q3_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q3_K;
const int mmq_y = MMQ_Y_Q3_K;
const int mmq_x = MOE_X_Q3_K;
const int mmq_y = MOE_Y_Q3_K;
const int nwarps = NWARPS_Q3_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -556,12 +556,12 @@ static void ggml_moe_q3_K_q8_1_cuda(
}
#if defined(USE_ROCM)
#define MMQ_X_Q4_K 64
#define MMQ_Y_Q4_K 128
#define MOE_X_Q4_K 64
#define MOE_Y_Q4_K 128
#define NWARPS_Q4_K 8
#else
#define MMQ_X_Q4_K 4
#define MMQ_Y_Q4_K 32
#define MOE_X_Q4_K 4
#define MOE_Y_Q4_K 32
#define NWARPS_Q4_K 4
#endif
@ -576,8 +576,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_K, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
const int mmq_x = MMQ_X_Q4_K;
const int mmq_y = MMQ_Y_Q4_K;
const int mmq_x = MOE_X_Q4_K;
const int mmq_y = MOE_Y_Q4_K;
const int nwarps = NWARPS_Q4_K;
moe_q<scalar_t, QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps,
@ -594,8 +594,8 @@ static void ggml_moe_q4_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q4_K;
const int mmq_y = MMQ_Y_Q4_K;
const int mmq_x = MOE_X_Q4_K;
const int mmq_y = MOE_Y_Q4_K;
const int nwarps = NWARPS_Q4_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -617,12 +617,12 @@ static void ggml_moe_q4_K_q8_1_cuda(
}
#if defined(USE_ROCM)
#define MMQ_X_Q5_K 64
#define MMQ_Y_Q5_K 128
#define MOE_X_Q5_K 64
#define MOE_Y_Q5_K 128
#define NWARPS_Q5_K 8
#else
#define MMQ_X_Q5_K 4
#define MMQ_Y_Q5_K 32
#define MOE_X_Q5_K 4
#define MOE_Y_Q5_K 32
#define NWARPS_Q5_K 4
#endif
@ -637,8 +637,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_K, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
const int mmq_x = MMQ_X_Q5_K;
const int mmq_y = MMQ_Y_Q5_K;
const int mmq_x = MOE_X_Q5_K;
const int mmq_y = MOE_Y_Q5_K;
const int nwarps = NWARPS_Q5_K;
moe_q<scalar_t, QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps,
@ -655,8 +655,8 @@ static void ggml_moe_q5_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q5_K;
const int mmq_y = MMQ_Y_Q5_K;
const int mmq_x = MOE_X_Q5_K;
const int mmq_y = MOE_Y_Q5_K;
const int nwarps = NWARPS_Q5_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@ -678,12 +678,12 @@ static void ggml_moe_q5_K_q8_1_cuda(
}
#if defined(USE_ROCM)
#define MMQ_X_Q6_K 64
#define MMQ_Y_Q6_K 128
#define MOE_X_Q6_K 64
#define MOE_Y_Q6_K 128
#define NWARPS_Q6_K 8
#else
#define MMQ_X_Q6_K 4
#define MMQ_Y_Q6_K 32
#define MOE_X_Q6_K 4
#define MOE_Y_Q6_K 32
#define NWARPS_Q6_K 4
#endif
@ -698,8 +698,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q6_K, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
const int mmq_x = MMQ_X_Q6_K;
const int mmq_y = MMQ_Y_Q6_K;
const int mmq_x = MOE_X_Q6_K;
const int mmq_y = MOE_Y_Q6_K;
const int nwarps = NWARPS_Q6_K;
moe_q<scalar_t, QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps,
@ -716,8 +716,8 @@ static void ggml_moe_q6_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
const int mmq_x = MMQ_X_Q6_K;
const int mmq_y = MMQ_Y_Q6_K;
const int mmq_x = MOE_X_Q6_K;
const int mmq_y = MOE_Y_Q6_K;
const int nwarps = NWARPS_Q6_K;
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;

View File

@ -14,7 +14,7 @@ __global__ void awq_marlin_repack_kernel(
int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
int start_k_tile = blockIdx.x * block_k_tiles;
auto start_k_tile = blockIdx.x * block_k_tiles;
if (start_k_tile >= k_tiles) {
return;
}
@ -51,8 +51,8 @@ __global__ void awq_marlin_repack_kernel(
int4* sh_ptr = sh + stage_size * pipe;
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
@ -70,8 +70,8 @@ __global__ void awq_marlin_repack_kernel(
return;
}
int warp_id = threadIdx.x / 32;
int th_id = threadIdx.x % 32;
auto warp_id = threadIdx.x / 32;
auto th_id = threadIdx.x % 32;
if (warp_id >= 4) {
return;
@ -265,4 +265,4 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
}
}

View File

@ -460,7 +460,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr, int size_m,
int size_k, int lda, int block_rows) {
int start_row = block_rows * blockIdx.x;
auto start_row = block_rows * blockIdx.x;
int finish_row = start_row + block_rows;
if (finish_row > size_m) {
finish_row = size_m;
@ -484,7 +484,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int base_k = 0;
for (int i = 0; i < iters; i++) {
int cur_k = base_k + threadIdx.x;
auto cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
@ -494,7 +494,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
if (rest) {
if (threadIdx.x < rest) {
int cur_k = base_k + threadIdx.x;
auto cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
@ -723,8 +723,8 @@ __global__ void Marlin(
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
int b_sh_wr = threadIdx.x * b_thread_vecs;
int b_sh_rd = threadIdx.x * b_thread_vecs;
auto b_sh_wr = threadIdx.x * b_thread_vecs;
auto b_sh_rd = threadIdx.x * b_thread_vecs;
// For act_order
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
@ -743,7 +743,7 @@ __global__ void Marlin(
s_sh_stride * slice_col + threadIdx.x;
}
}
int s_sh_wr = threadIdx.x;
auto s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// Zero-points
@ -756,7 +756,7 @@ __global__ void Marlin(
zp_sh_stride * slice_col + threadIdx.x;
}
}
int zp_sh_wr = threadIdx.x;
auto zp_sh_wr = threadIdx.x;
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
// We use a different scale layout for grouped and column-wise quantization as
@ -1047,7 +1047,7 @@ __global__ void Marlin(
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} else {
int warp_id = threadIdx.x / 32;
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
@ -1085,7 +1085,7 @@ __global__ void Marlin(
// Determine "position" inside the thread-block (based on warp and
// thread-id)
int warp_id = threadIdx.x / 32;
auto warp_id = threadIdx.x / 32;
int n_warps =
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
@ -1094,7 +1094,7 @@ __global__ void Marlin(
cur_k += warp_row * 16;
int th_id = threadIdx.x % 32;
auto th_id = threadIdx.x % 32;
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
int s_col_shift =
@ -1159,7 +1159,7 @@ __global__ void Marlin(
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
} else {
int warp_id = threadIdx.x / 32;
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
@ -1197,7 +1197,7 @@ __global__ void Marlin(
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
} else {
int warp_id = threadIdx.x / 32;
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
@ -1323,7 +1323,7 @@ __global__ void Marlin(
auto thread_block_reduce = [&]() {
constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) {
int red_idx = threadIdx.x / b_sh_stride_threads;
auto red_idx = threadIdx.x / b_sh_stride_threads;
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
constexpr int red_sh_delta = b_sh_stride_threads;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
@ -1390,7 +1390,7 @@ __global__ void Marlin(
4 * (threadIdx.x / 32) + threadIdx.x % 4;
c_gl_wr += (2 * thread_n_blocks) * slice_col;
constexpr int c_sh_wr_delta = active_threads;
int c_sh_wr = threadIdx.x;
auto c_sh_wr = threadIdx.x;
int row = (threadIdx.x % 32) / 4;

View File

@ -15,7 +15,7 @@ __global__ void gptq_marlin_repack_kernel(
int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
int start_k_tile = blockIdx.x * block_k_tiles;
auto start_k_tile = blockIdx.x * block_k_tiles;
if (start_k_tile >= k_tiles) {
return;
}
@ -71,8 +71,8 @@ __global__ void gptq_marlin_repack_kernel(
if constexpr (has_perm) {
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads;
uint32_t const* sh_perm_int_ptr =
reinterpret_cast<uint32_t const*>(sh_perm_ptr);
@ -88,8 +88,8 @@ __global__ void gptq_marlin_repack_kernel(
} else {
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
int first_k_packed = first_k / pack_factor;
@ -109,8 +109,8 @@ __global__ void gptq_marlin_repack_kernel(
return;
}
int warp_id = threadIdx.x / 32;
int th_id = threadIdx.x % 32;
auto warp_id = threadIdx.x / 32;
auto th_id = threadIdx.x % 32;
if (warp_id >= 4) {
return;
@ -339,4 +339,4 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
}
}

View File

@ -277,12 +277,12 @@ __global__ void Marlin(
b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
int b_sh_wr = threadIdx.x;
int b_sh_rd = threadIdx.x;
auto b_sh_wr = threadIdx.x;
auto b_sh_rd = threadIdx.x;
int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_sh_stride * slice_col + threadIdx.x;
int s_sh_wr = threadIdx.x;
auto s_sh_wr = threadIdx.x;
int s_sh_rd;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
@ -455,7 +455,7 @@ __global__ void Marlin(
auto thread_block_reduce = [&]() {
constexpr int red_off = threads / b_sh_stride / 2;
if (red_off >= 1) {
int red_idx = threadIdx.x / b_sh_stride;
auto red_idx = threadIdx.x / b_sh_stride;
constexpr int red_sh_stride = b_sh_stride * 4 * 2;
constexpr int red_sh_delta = b_sh_stride;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
@ -522,7 +522,7 @@ __global__ void Marlin(
4 * (threadIdx.x / 32) + threadIdx.x % 4;
c_gl_wr += (2 * thread_n_blocks) * slice_col;
constexpr int c_sh_wr_delta = active_threads;
int c_sh_wr = threadIdx.x;
auto c_sh_wr = threadIdx.x;
int row = (threadIdx.x % 32) / 4;

View File

@ -353,10 +353,10 @@ __global__ void Marlin(
b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
int b_sh_wr = threadIdx.x;
int b_sh_rd = threadIdx.x;
auto b_sh_wr = threadIdx.x;
auto b_sh_rd = threadIdx.x;
int s_tok_gl_rd = threadIdx.x;
auto s_tok_gl_rd = threadIdx.x;
// NOTE(HandH1998): activation scale s_tok need shuffle to [0, 8, 1, 9, 2, 10,
// 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] for example, 0, 8 row scales serve for
// thread 0, 1, 2, 3. For more details, refer to mma operand A layout as
@ -368,8 +368,8 @@ __global__ void Marlin(
int s_tok_sh_rd = (threadIdx.x % 32) / 4;
bool s_tok_sh_wr_pred = threadIdx.x < prob_m;
int s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
int s_ch_sh_wr = threadIdx.x;
auto s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
auto s_ch_sh_wr = threadIdx.x;
int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
2 * ((threadIdx.x % 32) % 4);
bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride;
@ -558,7 +558,7 @@ __global__ void Marlin(
auto thread_block_reduce = [&]() {
constexpr int red_off = threads / b_sh_stride / 2;
if (red_off >= 1) {
int red_idx = threadIdx.x / b_sh_stride;
auto red_idx = threadIdx.x / b_sh_stride;
constexpr int red_sh_stride = b_sh_stride * 4 * 2;
constexpr int red_sh_delta = b_sh_stride;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
@ -628,7 +628,7 @@ __global__ void Marlin(
8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2;
c_gl_wr += (4 * thread_n_blocks) * slice_col;
constexpr int c_sh_wr_delta = active_threads * 2;
int c_sh_wr = 2 * threadIdx.x;
auto c_sh_wr = 2 * threadIdx.x;
int row = (threadIdx.x % 32) / 4;

View File

@ -273,15 +273,15 @@ __global__ void Marlin_24(
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
int b_sh_wr = threadIdx.x * b_thread_vecs;
int b_sh_rd = threadIdx.x * b_thread_vecs;
auto b_sh_wr = threadIdx.x * b_thread_vecs;
auto b_sh_rd = threadIdx.x * b_thread_vecs;
int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) +
(threadIdx.x % (m_sh_stride));
m_gl_rd += (m_sh_stride)*slice_col;
m_gl_rd += m_gl_rd_delta_o * slice_row;
int m_sh_wr = threadIdx.x;
int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16;
auto m_sh_wr = threadIdx.x;
auto m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16;
int s_gl_rd;
if constexpr (group_blocks == -1) {
@ -291,7 +291,7 @@ __global__ void Marlin_24(
s_sh_stride * slice_col + threadIdx.x;
}
int s_sh_wr = threadIdx.x;
auto s_sh_wr = threadIdx.x;
int s_sh_rd;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
@ -516,7 +516,7 @@ __global__ void Marlin_24(
auto thread_block_reduce = [&]() {
constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) {
int red_idx = threadIdx.x / b_sh_stride_threads;
auto red_idx = threadIdx.x / b_sh_stride_threads;
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
constexpr int red_sh_delta = b_sh_stride_threads;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
@ -583,7 +583,7 @@ __global__ void Marlin_24(
8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4;
c_gl_wr += (2 * thread_n_blocks) * slice_col;
constexpr int c_sh_wr_delta = active_threads;
int c_sh_wr = threadIdx.x;
auto c_sh_wr = threadIdx.x;
int col = 2 * ((threadIdx.x % 32) % 4);

View File

@ -284,18 +284,18 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
// clang-format on
constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
const int warpid = threadIdx.x / WARP_SIZE;
const int laneid = threadIdx.x % WARP_SIZE;
const auto warpid = threadIdx.x / WARP_SIZE;
const auto laneid = threadIdx.x % WARP_SIZE;
const int lane4id = laneid % 4;
const int lane16id = laneid % 16;
const int rowid = laneid / 16;
const int seq_idx = blockIdx.x;
const int partition_idx = blockIdx.y;
const auto seq_idx = blockIdx.x;
const auto partition_idx = blockIdx.y;
constexpr int T_PAR_SIZE = 256; // token partition size set to 256
const int max_num_partitions = gridDim.y;
const auto max_num_partitions = gridDim.y;
const int context_len = context_lens[seq_idx];
@ -346,9 +346,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
// can be interpreted as B8x16 for 8 bit types
_B16x8 Klocal[TLOOP][QKHELOOP];
const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
const int wg_start_kv_head_idx = blockIdx.z;
const int total_num_heads = gridDim.z * GQA_RATIO;
const auto wg_start_head_idx = blockIdx.z * GQA_RATIO;
const auto wg_start_kv_head_idx = blockIdx.z;
const auto total_num_heads = gridDim.z * GQA_RATIO;
// for QK mfma, tokens in multiples of TOKENS_PER_WARP are spread across warps
// each mfma takes QH16xT16x16HE across warp
@ -789,14 +789,14 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
// clang-format on
constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
const int warpid = threadIdx.x / WARP_SIZE;
const int laneid = threadIdx.x % WARP_SIZE;
const auto warpid = threadIdx.x / WARP_SIZE;
const auto laneid = threadIdx.x % WARP_SIZE;
const int lane4id = laneid % 4;
const int seq_idx = blockIdx.x;
const int partition_idx = blockIdx.y;
const int partition_size = blockDim.x;
const int max_num_partitions = gridDim.y;
const auto seq_idx = blockIdx.x;
const auto partition_idx = blockIdx.y;
const auto partition_size = blockDim.x;
const auto max_num_partitions = gridDim.y;
const int context_len = context_lens[seq_idx];
const int partition_start_token_idx = partition_idx * partition_size;
@ -838,8 +838,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
qk_max[h] = -FLT_MAX;
}
const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
const int wg_start_kv_head_idx = blockIdx.z;
const auto wg_start_head_idx = blockIdx.z * GQA_RATIO;
const auto wg_start_kv_head_idx = blockIdx.z;
const int warp_start_token_idx =
partition_start_token_idx + warpid * WARP_SIZE;
@ -857,7 +857,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
// token id within partition
const int local_token_idx = threadIdx.x;
const auto local_token_idx = threadIdx.x;
// token id within sequence
const int global_token_idx = partition_start_token_idx + local_token_idx;
@ -1126,7 +1126,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
__syncthreads();
const int num_heads = gridDim.z * GQA_RATIO;
const auto num_heads = gridDim.z * GQA_RATIO;
float* max_logits_ptr =
max_logits + seq_idx * num_heads * max_num_partitions + partition_idx;
float* exp_sums_ptr =
@ -1268,14 +1268,14 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_partitions) {
const int num_heads = gridDim.x;
const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y;
const auto num_heads = gridDim.x;
const auto head_idx = blockIdx.x;
const auto seq_idx = blockIdx.y;
const int context_len = context_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
[[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int warpid = threadIdx.x / WARP_SIZE;
[[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE;
const auto warpid = threadIdx.x / WARP_SIZE;
[[maybe_unused]] const auto laneid = threadIdx.x % WARP_SIZE;
__shared__ float shared_global_exp_sum;
// max num partitions supported is warp_size * NPAR_LOOPS
@ -1294,7 +1294,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
#pragma unroll
for (int i = 0; i < NPAR_LOOPS; i++) {
const int partition_no = i * WARP_SIZE + threadIdx.x;
const auto partition_no = i * WARP_SIZE + threadIdx.x;
valid_partition[i] =
(partition_no < num_partitions) ? partition_no : last_valid_partition;
}
@ -1324,7 +1324,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
}
#pragma unroll
for (int i = 0; i < NPAR_LOOPS; i++) {
const int partition_no = i * WARP_SIZE + threadIdx.x;
const auto partition_no = i * WARP_SIZE + threadIdx.x;
rescaled_exp_sum[i] *= (partition_no < num_partitions)
? expf(reg_max_logit[i] - max_logit)
: 0.0f;
@ -1336,7 +1336,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
}
#pragma unroll
for (int i = 0; i < NPAR_LOOPS; i++) {
const int partition_no = i * WARP_SIZE + threadIdx.x;
const auto partition_no = i * WARP_SIZE + threadIdx.x;
shared_exp_sums[partition_no] = rescaled_exp_sum[i];
}

View File

@ -111,6 +111,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
") -> ()");
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
ops.def(
"block_table_appends(Tensor append_row_indices, "
"Tensor append_row_indices_cpu, Tensor append_cumsums, "
"Tensor append_cumsums_cpu, Tensor append_block_ids, "
"Tensor append_block_ids_cpu, Tensor! block_table, int num_appends, "
"int total_num_append_blocks) -> ()");
ops.impl("block_table_appends", torch::kCUDA, &block_table_appends);
ops.def(
"block_table_moves(Tensor src_dst_n, Tensor src_dst_n_cpu, "
"Tensor! block_table, int num_moves) -> ()");
ops.impl("block_table_moves", torch::kCUDA, &block_table_moves);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(

View File

@ -129,6 +129,9 @@ in progress.
- **Spec Decode**: Currently, only ngram-based spec decode is supported in V1. There
will be follow-up work to support other types of spec decode (e.g., see [PR #13933](https://github.com/vllm-project/vllm/pull/13933)). We will prioritize the support for Eagle, MTP compared to draft model based spec decode.
- **Multimodal Models**: V1 is almost fully compatible with V0 except that interleaved modality input is not supported yet.
See [here](https://github.com/orgs/vllm-project/projects/8) for the status of upcoming features and optimizations.
#### Features to Be Supported
- **FP8 KV Cache**: While vLLM V1 introduces new FP8 kernels for model weight quantization, support for an FP8 keyvalue cache is not yet available. Users must continue using FP16 (or other supported precisions) for the KV cache.

View File

@ -361,6 +361,7 @@ def run_llava_next_video(questions: list[str],
engine_args = EngineArgs(
model="llava-hf/LLaVA-NeXT-Video-7B-hf",
max_model_len=8192,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)

View File

@ -4,14 +4,14 @@
# Dependencies for CPUs
torch==2.6.0+cpu; platform_machine == "x86_64"
torch==2.6.0; platform_system == "Darwin"
torch==2.5.1; platform_machine == "ppc64le" or platform_machine == "aarch64"
torch==2.6.0; platform_machine == "ppc64le" or platform_machine == "aarch64"
torch==2.7.0.dev20250304; platform_machine == "s390x"
# required for the image processor of minicpm-o-2_6, this must be updated alongside torch
torchaudio; platform_machine != "ppc64le" and platform_machine != "s390x"
torchaudio==2.5.1; platform_machine == "ppc64le"
torchaudio==2.6.0; platform_machine == "ppc64le"
# required for the image processor of phi3v, this must be updated alongside torch
torchvision; platform_machine != "ppc64le" and platform_machine != "s390x"
torchvision==0.20.1; platform_machine == "ppc64le"
torchvision==0.21.0; platform_machine == "ppc64le"
datasets # for benchmark scripts

View File

@ -749,3 +749,72 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
torch.testing.assert_close(dst, expected)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode()
def test_concat_and_cache_mla_cpu(
kv_lora_rank: int,
qk_rope_head_dim: int,
num_tokens: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
) -> None:
device = "cpu"
kv_cache_dtype = "auto"
current_platform.seed_everything(seed)
torch.set_default_device(device)
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens,
qk_rope_head_dim,
dtype=dtype,
device=device)
entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
for i in range(num_tokens):
slot = slot_mapping[i].item()
block_idx = slot // block_size
block_offset = slot % block_size
ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]
if kv_cache_dtype == "fp8":
ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
ops.convert_fp8(ref_kv_cache,
ref_temp,
scale.item(),
kv_dtype=kv_cache_dtype)
else:
ref_kv_cache = ref_temp
opcheck(
torch.ops._C_cache_ops.concat_and_cache_mla,
(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale)
torch.testing.assert_close(kv_cache, ref_kv_cache)

View File

@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import torch.nn.functional as F
from torch import Tensor
import vllm._custom_ops as ops
from vllm.platforms import current_platform
def cdiv(a, b):
return (a + b - 1) // b
def ref_mla(
out: Tensor, # (bs, num_heads, v_head_dim)
query: Tensor, # (bs, num_heads, head_dim)
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
scale: float,
block_tables: Tensor, # (bs, max_num_blocks)
seq_lens: Tensor, # (bs,)
):
bs, num_heads, v_head_dim = out.shape
head_dim = query.shape[2]
for i in range(bs):
# gather and flatten KV-cache
kv = kv_cache[
block_tables[i]] # (max_num_blocks, block_size, head_dim)
kv = kv.view(1, -1,
head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim)
v = kv[:, :, :v_head_dim]
q = query[i].view(num_heads, 1, head_dim)
o = F.scaled_dot_product_attention(q,
kv,
v,
scale=scale,
enable_gqa=True)
out[i] = o.view(num_heads, v_head_dim)
return out
@pytest.mark.parametrize("bs", [4])
@pytest.mark.parametrize("mean_seq_len", [256])
@pytest.mark.parametrize("h_q", [16])
@pytest.mark.parametrize("d", [576])
@pytest.mark.parametrize("dv", [512])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("dtype", [torch.float, torch.half, torch.bfloat16])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
def test_mla_decode_cpu(
bs: int,
mean_seq_len: int,
h_q: int,
d: int,
dv: int,
block_size: int,
dtype: torch.dtype,
varlen: bool,
):
torch.set_default_dtype(dtype)
torch.manual_seed(0)
scale = d**(-0.5)
if varlen:
seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
seq_lens = seq_lens.clip(2).to(torch.int32)
else:
seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32)
max_seq_len = seq_lens.max().item()
seqlen_pad = cdiv(max_seq_len, 256) * 256 # is this necessary?
q = torch.randn(bs, h_q, d)
block_table = torch.arange(bs * seqlen_pad // block_size,
dtype=torch.int32)
block_table = block_table.view(bs, seqlen_pad // block_size)
kv_cache = torch.randn(block_table.numel(), block_size, d)
for i, seq_len in enumerate(seq_lens.tolist()):
kv_cache.view(bs, seqlen_pad, d)[i, seq_len:] = float("nan")
out_mla = q.new_zeros(bs, h_q, dv)
ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table,
seq_lens)
out_ref = q.new_zeros(bs, h_q, dv)
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
assert not out_mla.isnan().any(), "Likely read out of bounds"
torch.testing.assert_close(out_mla, out_ref)

View File

@ -163,24 +163,24 @@ VLM_TEST_SETTINGS = {
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
#### Extended model tests
# "aria": VLMTestInfo(
# models=["rhymes-ai/Aria"],
# test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
# prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
# img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
# max_model_len=4096,
# max_num_seqs=2,
# auto_cls=AutoModelForImageTextToText,
# single_image_prompts=IMAGE_ASSETS.prompts({
# "stop_sign": "<vlm_image>Please describe the image shortly.",
# "cherry_blossom": "<vlm_image>Please infer the season with reason.", # noqa: E501
# }),
# multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501
# stop_str=["<|im_end|>"],
# image_size_factors=[(0.10, 0.15)],
# max_tokens=64,
# marks=[large_gpu_mark(min_gb=64)],
# ),
"aria": VLMTestInfo(
models=["rhymes-ai/Aria"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
single_image_prompts=IMAGE_ASSETS.prompts({
"stop_sign": "<vlm_image>Please describe the image shortly.",
"cherry_blossom": "<vlm_image>Please infer the season with reason.", # noqa: E501
}),
multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501
stop_str=["<|im_end|>"],
image_size_factors=[(0.10, 0.15)],
max_tokens=64,
marks=[large_gpu_mark(min_gb=64)],
),
"blip2": VLMTestInfo(
models=["Salesforce/blip2-opt-2.7b"],
test_type=VLMTestType.IMAGE,
@ -352,6 +352,7 @@ VLM_TEST_SETTINGS = {
prompt_formatter=lambda vid_prompt: f"USER: {vid_prompt} ASSISTANT:",
num_video_frames=16,
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output,
),
@ -384,7 +385,7 @@ VLM_TEST_SETTINGS = {
),
"minicpmo_26": VLMTestInfo(
models=["openbmb/MiniCPM-o-2_6"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
test_type=(VLMTestType.IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
@ -393,9 +394,21 @@ VLM_TEST_SETTINGS = {
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
),
"minicpmo_26_multi_image": VLMTestInfo(
models=["openbmb/MiniCPM-o-2_6"],
test_type=(VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
max_num_seqs=2,
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
marks=[large_gpu_mark(min_gb=32)],
),
"minicpmv_26": VLMTestInfo(
models=["openbmb/MiniCPM-V-2_6"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
test_type=(VLMTestType.IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
@ -404,6 +417,18 @@ VLM_TEST_SETTINGS = {
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
),
"minicpmv_26_multi_image": VLMTestInfo(
models=["openbmb/MiniCPM-V-2_6"],
test_type=(VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
max_num_seqs=2,
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
marks=[large_gpu_mark(min_gb=32)],
),
"molmo": VLMTestInfo(
models=["allenai/Molmo-7B-D-0924"],
test_type=(VLMTestType.IMAGE),

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
import copy
from functools import partial
from typing import Optional, Union
@ -29,7 +28,7 @@ def _test_processing_correctness(
hit_rate: float,
num_batches: int,
simplify_rate: float,
ignore_mm_keys: Optional[list[str]] = None,
ignore_mm_keys: Optional[set[str]] = None,
):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_available_online(on_fail="skip")
@ -145,7 +144,7 @@ def _test_processing_correctness_hf(
baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor,
batch_idx: int,
ignore_mm_keys: Optional[list[str]] = None,
ignore_mm_keys: Optional[set[str]] = None,
):
if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
# For some multimodal models, tokenizer will always add bos_token
@ -167,11 +166,12 @@ def _test_processing_correctness_hf(
hf_processor_mm_kwargs={},
)
assert _inputs_equal(
_assert_inputs_equal(
baseline_result,
cached_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
baseline_tokenized_result = baseline_processor.apply(
token_prompt,
@ -179,11 +179,12 @@ def _test_processing_correctness_hf(
hf_processor_mm_kwargs={},
)
assert _inputs_equal(
_assert_inputs_equal(
baseline_result,
baseline_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
cached_tokenized_result = cached_processor.apply(
token_prompt,
@ -191,11 +192,12 @@ def _test_processing_correctness_hf(
hf_processor_mm_kwargs={},
)
assert _inputs_equal(
_assert_inputs_equal(
cached_result,
cached_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
def _test_processing_correctness_mistral(
@ -206,7 +208,7 @@ def _test_processing_correctness_mistral(
baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor,
batch_idx: int,
ignore_mm_keys: Optional[list[str]] = None,
ignore_mm_keys: Optional[set[str]] = None,
):
images = mm_data.get("image", [])
if not isinstance(images, list):
@ -233,11 +235,12 @@ def _test_processing_correctness_mistral(
hf_processor_mm_kwargs={},
)
assert _inputs_equal(
_assert_inputs_equal(
baseline_tokenized_result,
cached_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
ignore_mm_keys=ignore_mm_keys,
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
)
# yapf: disable
@ -261,6 +264,7 @@ def _test_processing_correctness_mistral(
"TIGER-Lab/Mantis-8B-siglip-llama3",
"mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b",
"openbmb/MiniCPM-Llama3-V-2_5",
"openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6",
"allenai/Molmo-7B-D-0924",
@ -290,7 +294,7 @@ def test_processing_correctness(
# In Ultravox, the audio_features can be different depending on padding
# The slight difference should not be a problem though, since
# attention_mask lets us ignore the difference.
ignore_mm_keys = ['audio_features']
ignore_mm_keys = {"audio_features"}
_test_processing_correctness(
model_id,
@ -328,38 +332,26 @@ def test_processing_correctness_phi3v(
)
def _inputs_equal(
def _assert_inputs_equal(
a: MultiModalInputs,
b: MultiModalInputs,
ignore_mm_keys: Optional[list[str]] = None,
*,
ignore_mm_keys: Optional[set[str]] = None,
msg: str = "",
):
return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys(
b, ignore_mm_keys)
if ignore_mm_keys is None:
ignore_mm_keys = set()
if msg is None:
assert "mm_kwargs" in a and "mm_kwargs" in b
else:
assert "mm_kwargs" in a and "mm_kwargs" in b, msg
def _drop_mm_kwargs_keys(
result: MultiModalInputs,
ignore_mm_keys: Optional[list[str]] = None,
) -> MultiModalInputs:
"""Drop specified keys from result['mm_kwargs'].
for key in ignore_mm_keys:
a["mm_kwargs"].pop(key, None)
b["mm_kwargs"].pop(key, None)
This is mainly to avoid doing exact match of audio_features in ultravox.
Args:
result: Result to drop keys from
ignore_mm_keys: List of keys to ignore, e.g. ['audio_features']
"""
if not ignore_mm_keys:
return result
if 'mm_kwargs' in result:
result = copy.deepcopy(result)
mm_kwargs = result['mm_kwargs']
for key in ignore_mm_keys:
mm_kwargs.pop(key, None)
for items in mm_kwargs._items_by_modality.values():
for item in items:
for key in ignore_mm_keys:
item.pop(key, None)
return result
if msg is None:
assert a == b
else:
assert a == b, msg

View File

View File

@ -0,0 +1,307 @@
# SPDX-License-Identifier: Apache-2.0
import unittest.mock as mock
import pytest
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
# Mock torch_xla module since it may not be available in the test environments
torch_xla_patcher = mock.patch.dict(
"sys.modules", {
"torch_xla": mock.MagicMock(),
"torch_xla.core.xla_model": mock.MagicMock(),
"torch_xla.runtime": mock.MagicMock(),
})
torch_xla_patcher.start()
# Mock the PallasAttentionBackend
pallas_attention_backend_patcher = mock.patch(
"vllm.v1.worker.tpu_model_runner.PallasAttentionBackend", )
pallas_attention_backend_patcher.start()
@pytest.fixture
def model_runner():
# Patchers have already been started at module level.
scheduler_config = SchedulerConfig(
max_num_seqs=10,
max_num_batched_tokens=512,
max_model_len=512,
)
model_config = ModelConfig(
model="facebook/opt-125m",
task="generate",
tokenizer="facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=True,
dtype="bfloat16", # TPUs typically use bfloat16
seed=42,
)
cache_config = CacheConfig(
block_size=16,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
scheduler_config=scheduler_config,
)
device = "xla:0" # Mocking TPU device
with mock.patch("vllm.v1.worker.tpu_model_runner.torch"), \
mock.patch("vllm.v1.worker.tpu_model_runner.xm"), \
mock.patch("vllm.v1.worker.tpu_model_runner.xr"):
return TPUModelRunner(vllm_config, device)
@pytest.fixture(autouse=True, scope="session")
def cleanup_patches():
yield
torch_xla_patcher.stop()
pallas_attention_backend_patcher.stop()
def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
new_reqs = []
num_scheduled_tokens = {}
total_num_scheduled_tokens = 0
for req_id in req_ids:
new_reqs.append(
NewRequestData(
req_id=req_id,
prompt_token_ids=[1, 2, 3],
prompt="test",
mm_inputs=[],
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),
block_ids=[0],
num_computed_tokens=0,
lora_request=None,
))
num_scheduled_tokens[req_id] = 3
total_num_scheduled_tokens += num_scheduled_tokens[req_id]
return SchedulerOutput(
scheduled_new_reqs=new_reqs,
scheduled_cached_reqs=[],
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
def _is_req_scheduled(model_runner, req_id: str) -> bool:
return req_id in model_runner.input_batch.req_id_to_index
def _is_req_added(model_runner, req_id: str) -> bool:
return req_id in model_runner.requests
def _is_sampling_metadata_changed(model_runner,
sampling_metadata_before: SamplingMetadata):
return model_runner.input_batch.sampling_metadata is not (
sampling_metadata_before)
def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
req_index = model_runner.input_batch.req_id_to_index[req_id]
block_table = model_runner.input_batch.block_table
req_state = model_runner.requests[req_id]
if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
return False
num_blocks = block_table.num_blocks_per_row[req_index]
return (block_table.block_table_np[req_index, :num_blocks] ==
req_state.block_ids).all()
def test_update_states_new_request(model_runner):
req_id = "req_0"
# new req
scheduler_output = _schedule_new_request(req_id)
metadata_before = model_runner.input_batch.sampling_metadata
model_runner._update_states(scheduler_output)
assert _is_sampling_metadata_changed(model_runner, metadata_before)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
assert _is_req_state_block_table_match(model_runner, req_id)
def test_update_states_request_finished(model_runner):
req_id = "req_0"
# new req
scheduler_output = _schedule_new_request(req_id)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
# finish req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids={req_id},
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
model_runner._update_states(scheduler_output)
assert _is_sampling_metadata_changed(model_runner, metadata_before)
assert not _is_req_added(model_runner, req_id)
assert not _is_req_scheduled(model_runner, req_id)
def test_update_states_request_resumed(model_runner):
req_id = "req_0"
# new req
scheduler_output = _schedule_new_request(req_id)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
# unschedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert not _is_req_scheduled(model_runner, req_id)
# resume req
cached_req_data = CachedRequestData(
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=[],
num_computed_tokens=0,
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[cached_req_data],
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
model_runner._update_states(scheduler_output)
assert _is_sampling_metadata_changed(model_runner, metadata_before)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
assert _is_req_state_block_table_match(model_runner, req_id)
def test_update_states_no_changes(model_runner):
req_id = "req_0"
# new req
scheduler_output = _schedule_new_request(req_id)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
# schedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
model_runner._update_states(scheduler_output)
assert not _is_sampling_metadata_changed(model_runner, metadata_before)
assert _is_req_added(model_runner, req_id)
assert _is_req_scheduled(model_runner, req_id)
assert _is_req_state_block_table_match(model_runner, req_id)
def test_update_states_request_unscheduled(model_runner):
req_ids = ("req_0", "req_1")
# new reqs
scheduler_output = _schedule_new_request(*req_ids)
model_runner._update_states(scheduler_output)
assert _is_req_added(model_runner, req_ids[0])
assert _is_req_scheduled(model_runner, req_ids[0])
assert _is_req_added(model_runner, req_ids[1])
assert _is_req_scheduled(model_runner, req_ids[1])
# unschedule req_1
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={req_ids[0]: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
metadata_before = model_runner._update_states(scheduler_output)
assert _is_sampling_metadata_changed(model_runner, metadata_before)
assert _is_req_added(model_runner, req_ids[0])
assert _is_req_scheduled(model_runner, req_ids[0])
assert _is_req_added(model_runner, req_ids[1])
assert not _is_req_scheduled(model_runner, req_ids[1])

View File

@ -124,6 +124,18 @@ def paged_attention_rocm(
kv_cache_dtype, k_scale, v_scale)
def mla_decode_kvcache_cpu(
out: torch.Tensor,
query: torch.Tensor,
kv_cache: torch.Tensor,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
) -> None:
torch.ops._C_cpu.mla_decode_kvcache(out, query, kv_cache, scale,
block_tables, seq_lens)
# pos encoding ops
def rotary_embedding(
positions: torch.Tensor,
@ -190,6 +202,33 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
block_table_bound)
def block_table_appends(
append_row_indices: torch.Tensor,
append_row_indices_cpu: torch.Tensor,
append_cumsums: torch.Tensor,
append_cumsums_cpu: torch.Tensor,
append_block_ids: torch.Tensor,
append_block_ids_cpu: torch.Tensor,
block_table: torch.Tensor,
num_appends: int,
total_num_append_blocks: int,
) -> None:
torch.ops._C.block_table_appends.default(
append_row_indices, append_row_indices_cpu, append_cumsums,
append_cumsums_cpu, append_block_ids, append_block_ids_cpu,
block_table, num_appends, total_num_append_blocks)
def block_table_moves(
src_dst_n: torch.Tensor,
src_dst_n_cpu: torch.Tensor,
block_table: torch.Tensor,
num_moves: int,
) -> None:
torch.ops._C.block_table_moves.default(src_dst_n, src_dst_n_cpu,
block_table, num_moves)
# fused quant layer norm ops
def rms_norm_dynamic_per_token_quant(
input: torch.Tensor,

View File

@ -187,15 +187,28 @@ class ipex_ops:
gen_: torch.Generator,
logits_soft_cap: float,
) -> None:
ipex.llm.functional.varlen_attention(query.contiguous(),
key.contiguous(),
value.contiguous(), out,
seqlen_q.int(), seqlen_k.int(),
max_seqlen_q, max_seqlen_k,
pdropout, softmax_scale,
zero_tensors, is_causal,
return_softmax, gen_,
logits_soft_cap)
if ipex.__version__.endswith("cpu"):
if logits_soft_cap != 0.0:
raise ValueError("IPEX CPU does not support logits_soft_cap")
ipex.llm.functional.varlen_attention(query.contiguous(),
key.contiguous(),
value.contiguous(), out,
seqlen_q.int(),
seqlen_k.int(), max_seqlen_q,
max_seqlen_k, pdropout,
softmax_scale, zero_tensors,
is_causal, return_softmax,
gen_)
else: # XPU build
ipex.llm.functional.varlen_attention(query.contiguous(),
key.contiguous(),
value.contiguous(), out,
seqlen_q.int(),
seqlen_k.int(), max_seqlen_q,
max_seqlen_k, pdropout,
softmax_scale, zero_tensors,
is_causal, return_softmax,
gen_, logits_soft_cap)
@staticmethod
def reshape_and_cache(

View File

@ -0,0 +1,303 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import vllm._custom_ops as ops
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadataBuilder,
AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState
from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
class CPUMLABackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "CPU_MLA"
@staticmethod
def get_metadata_cls() -> Type["CPUMLAMetadata"]:
return CPUMLAMetadata
@staticmethod
def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]:
return CPUMLAMetadataBuilder
@staticmethod
def get_state_cls() -> Type["MLACommonState"]:
return MLACommonState
@staticmethod
def get_impl_cls() -> Type["CPUMLAImpl"]:
return CPUMLAImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
) -> Tuple[int, ...]:
return (num_blocks, block_size, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
ops.copy_blocks_mla(kv_caches, src_to_dists)
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [576]
@dataclass
class CPUMLAMetadata(TorchSDPAMetadata):
# New for MLA
# Input positions for rotrary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor = None
# required by MLACommonImpl
is_profile_run: bool = False
class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):
def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
self.chunked_prefill = input_builder.chunked_prefill
self.input_builder = input_builder
assert not self.chunked_prefill, \
"chunked prefill is currently not supported"
def prepare(self):
self.input_data = self.input_builder.input_data
def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size):
input_data = self.input_data
prefill_seq_lens = seq_lens[0:input_data.num_prefills]
prefill_query_lens = query_lens[0:input_data.num_prefills]
slot_mapping = torch.tensor(input_data.slot_mapping,
dtype=torch.long,
device="cpu")
# metadata for prefill
if input_data.num_prefills > 0:
query_lens_tensor = torch.tensor(prefill_query_lens,
dtype=torch.int32,
device="cpu")
kv_lens_tensor = torch.tensor(prefill_seq_lens,
dtype=torch.int32,
device="cpu")
query_start_loc = torch.zeros(input_data.num_prefills + 1,
dtype=torch.int32,
device="cpu")
kv_start_loc = torch.zeros(input_data.num_prefills + 1,
dtype=torch.int32,
device="cpu")
torch.cumsum(query_lens_tensor,
dim=0,
dtype=torch.int32,
out=query_start_loc[1:])
torch.cumsum(kv_lens_tensor,
dim=0,
dtype=torch.int32,
out=kv_start_loc[1:])
max_query_len = max(prefill_query_lens)
max_kv_len = max(prefill_seq_lens)
# for chunked-prefill
if self.chunked_prefill:
prefill_block_tables = make_tensor_with_pad(
self.input_data.prefill_block_tables,
pad=0,
dtype=torch.int32,
device="cpu",
)
else:
prefill_block_tables = None
else:
query_start_loc = None
kv_start_loc = None
max_query_len = None
max_kv_len = None
prefill_block_tables = None
# metadata for decode
if input_data.num_decode_tokens != 0:
seq_lens_tensor = torch.tensor(
input_data.seq_lens[input_data.num_prefills:],
dtype=torch.int32,
device="cpu",
)
block_tables = make_tensor_with_pad(
self.input_data.decode_block_tables,
pad=0,
dtype=torch.int32,
device="cpu",
)
else:
block_tables = torch.tensor([])
seq_lens_tensor = torch.tensor(
input_data.seq_lens[:input_data.num_prefills],
dtype=torch.int32,
device="cpu",
)
# For multi-modal models
placeholder_index_maps = None
if len(input_data.multi_modal_inputs_list) != 0:
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
input_data.multi_modal_placeholder_maps.items()
}
return CPUMLAMetadata(
chunked_prefill=self.chunked_prefill,
seq_lens=prefill_seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_kv_len=max_kv_len,
query_start_loc=query_start_loc,
kv_start_loc=kv_start_loc,
max_decode_seq_len=input_data.max_decode_seq_len,
num_prefills=input_data.num_prefills,
num_prefill_tokens=input_data.num_prefill_tokens,
num_decode_tokens=input_data.num_decode_tokens,
block_tables=block_tables,
prefill_block_tables=prefill_block_tables,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
input_positions=torch.tensor([self.input_data.input_positions]))
class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**mla_args)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"CPUMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"CPUMLAImpl")
# states is implemented.
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"CPUMLAImpl with FP8 KV cache not yet supported")
def _forward_prefill(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: CPUMLAMetadata, # type: ignore[override]
) -> torch.Tensor:
prefill_metadata = attn_metadata.prefill_metadata
assert prefill_metadata is not None
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
value=0)
output = torch.empty_like(q)
ipex_ops.varlen_attention(
query=q,
key=k,
value=v_padded,
out=output,
seqlen_q=prefill_metadata.query_start_loc,
seqlen_k=prefill_metadata.query_start_loc,
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.max_query_len,
pdropout=0.0,
softmax_scale=self.scale,
zero_tensors=False,
is_causal=True,
return_softmax=False,
gen_=None,
logits_soft_cap=0.0,
)
# remove padding
output = output.view(-1, self.num_heads,
q.shape[-1])[..., :v.shape[-1]]
output = output.reshape(-1, self.num_heads * v.shape[-1])
return self.o_proj(output)[0]
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: CPUMLAMetadata, # type: ignore[override]
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
decode_meta = attn_metadata.decode_metadata
assert decode_meta is not None
q = torch.cat([q_nope, q_pe], dim=-1)
o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank)
# Run MQA
ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
decode_meta.block_tables,
decode_meta.seq_lens_tensor)
return self._v_up_proj_and_o_proj(o)

View File

@ -204,7 +204,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
UnquantizedLinearMethod)
@ -212,18 +211,27 @@ from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
if HAS_TRITON:
from vllm.attention.ops.triton_flash_attention import triton_attention
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
else:
merge_attn_states = None
triton_attention = None
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
is_vllm_fa = True
except ImportError:
# For rocm use upstream flash attention
from flash_attn import flash_attn_varlen_func
is_vllm_fa = False
from vllm.attention.ops.triton_flash_attention import triton_attention
try:
# For rocm use upstream flash attention
from flash_attn import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,

View File

@ -221,6 +221,9 @@ class ModelConfig:
factors.append(self.trust_remote_code)
factors.append(self.rope_scaling)
factors.append(self.rope_theta)
# rope cos/sin cache depends on the max_position_embeddings
factors.append(
getattr(self.hf_config, "max_position_embeddings", "None"))
return hashlib.sha256(str(factors).encode()).hexdigest()
def __init__(

View File

@ -63,9 +63,6 @@ class Gemma3ImagePixelInputs(TypedDict):
Shape: `(batch_size, num_images, num_embeds)`
"""
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
Gemma3ImageInputs = Gemma3ImagePixelInputs
@ -295,8 +292,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
# HF processor pops the `num_crops` kwarg, which is needed by vLLM
if (images := mm_data.get("images")) is not None:
assert isinstance(images, list)
parsed_images = (self._get_data_parser().parse_mm_data({
"image":
images
@ -319,11 +314,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
num_embeds = [
len(image_repl_feature_tokens)
for image_repl_feature_tokens in image_repls_feature_tokens
]
processed_outputs["num_embeds"] = torch.tensor(num_embeds)
vocab = tokenizer.get_vocab()
image_token_id = vocab[tokenizer.image_token]
@ -356,7 +346,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
"image", num_crops + 1),
num_crops=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
@ -585,7 +574,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
pixel_values = kwargs.pop("pixel_values", None)
num_crops = kwargs.pop("num_crops", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
num_embeds = kwargs.pop("num_embeds", None)
image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None:
@ -603,10 +591,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
pixel_values = flatten_bn(pixel_values, concat=True)
num_crops = flatten_bn(num_crops, concat=True)
@ -615,7 +599,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
pixel_values=self._validate_pixel_values(pixel_values),
num_patches=num_crops + 1,
embed_is_patch=embed_is_patch,
num_embeds=num_embeds,
)
def _image_pixels_to_features(
@ -658,7 +641,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
image_features,
image_input["num_embeds"],
image_input["embed_is_patch"],
))

View File

@ -69,9 +69,6 @@ class InternVLImagePixelInputs(TypedDict):
Shape: `(batch_size, num_images, num_embeds)`
"""
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
@ -426,7 +423,6 @@ class BaseInternVLProcessor(ABC):
tokenizer = self.tokenizer
image_token_id = self.image_token_id
num_embeds = list[int]()
embed_is_patch = list[torch.Tensor]()
for pixel_values in pixel_values_lst:
@ -438,11 +434,9 @@ class BaseInternVLProcessor(ABC):
add_special_tokens=False)
text = [t.replace('<image>', image_repl.full, 1) for t in text]
num_embeds.append(len(feature_tokens))
embed_is_patch.append(
torch.tensor(feature_tokens) == image_token_id)
image_inputs["num_embeds"] = torch.tensor(num_embeds)
image_inputs["embed_is_patch"] = embed_is_patch
text_inputs = self.tokenizer(text)
@ -607,7 +601,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
"image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)
@ -840,7 +833,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
image_num_patches = kwargs.pop("image_num_patches", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
num_embeds = kwargs.pop("num_embeds", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values_flat is None and image_embeds is None:
@ -873,10 +865,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True)
@ -886,7 +874,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
pixel_values_flat),
num_patches=image_num_patches,
embed_is_patch=embed_is_patch,
num_embeds=num_embeds,
)
raise AssertionError("This line should be unreachable.")
@ -941,7 +928,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
image_features,
image_input["num_embeds"],
image_input["embed_is_patch"],
))

View File

@ -76,9 +76,6 @@ class PixtralHFImagePixelInputs(TypedDict):
Shape: `(batch_size, num_images, num_embeds)`
"""
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
@ -358,15 +355,10 @@ class PixtralHFMultiModalProcessor(
image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"]
]
num_embeds = torch.tensor([(ncols + 1) * nrows
for ncols, nrows in tile_sizes])
# Each image may result to masks of different sizes, so we need to
# later use `num_embeds` to get per-image masks.
embed_is_patch = [
torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes
]
processed_outputs["num_embeds"] = num_embeds
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs
@ -378,7 +370,6 @@ class PixtralHFMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
@ -627,16 +618,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
num_embeds = kwargs.pop("num_embeds")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
return PixtralHFImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
embed_is_patch=embed_is_patch,
num_embeds=num_embeds,
)
return LlavaImagePixelInputs(
@ -738,7 +723,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
vision_embeddings,
image_input["num_embeds"],
image_input["embed_is_patch"],
))

View File

@ -23,7 +23,7 @@
# limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
from typing import (Any, Callable, Dict, Literal, Optional, Set, Tuple,
TypedDict, Union)
import torch
@ -43,24 +43,26 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
MiniCPMVMultiModalDataParser,
MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
_minicpmv_field_config)
from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
maybe_prefix)
CPU_DEVICE = torch.device("cpu")
class MiniCPMOAudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
data: torch.Tensor
audio_features: torch.Tensor
"""
Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
Slice here means chunk. Audio that is too long will be split into slices,
which is the same as image.
Padding is used therefore `data` is `torch.Tensor`.
Padding is used therefore `audio_features` is `torch.Tensor`.
"""
audio_feature_lens: torch.Tensor
@ -68,7 +70,7 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
Shape: `(batch_size * num_audios * num_slices)`
This should be feature length of each audio slice,
which equals to `data.shape[-1]`
which equals to `audio_features.shape[-1]`
"""
audio_bounds: torch.Tensor
@ -81,7 +83,7 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
class MiniCPMOAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"]
data: List[torch.Tensor]
audio_embeds: torch.Tensor
"""
Shape: `(batch_size * num_images * num_slices, hidden_size)`
@ -102,18 +104,11 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
return dict(
**_minicpmv_field_config(hf_inputs),
audio_features=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_num_slices=MultiModalFieldConfig.batched("audio"),
audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_features=MultiModalFieldConfig.batched("audio"),
audio_feature_lens=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.batched("audio"),
)
@ -153,9 +148,6 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
audio_pattern = "(<audio>./</audio>)"
def get_supported_mm_modalities(self) -> List[str]:
return ["image", "video", "audio"]
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None, "audio": None}
@ -277,95 +269,47 @@ class MiniCPMOMultiModalProcessor(
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
mm_data = dict(mm_data)
if (audios := mm_data.get("audios")) is None:
return {}
audios = mm_data.pop("audios", [])
audio_embeds = mm_data.pop("audio_embeds", [])
if isinstance(audios, (list, torch.Tensor)) and len(audios) > 0:
audio_outputs = {
"audio_lens": [],
"audio_features": [],
"audio_feature_lens": [],
"audio_num_segments": []
}
for audio in audios:
single_audio_outputs = super().call_base_hf_processor(
prompt=self.info.audio_pattern,
mm_data={
"audios": audio,
"chunk_input": True
},
mm_kwargs=mm_kwargs)
audio_outputs["audio_lens"].append(len(audio))
audio_outputs["audio_features"].append(
single_audio_outputs["audio_features"])
audio_outputs["audio_num_segments"].append(
len(single_audio_outputs["audio_feature_lens"][0]))
audio_outputs["audio_feature_lens"] += \
single_audio_outputs["audio_feature_lens"]
audio_outputs["audio_features"] = [
audio_feature for single_audio_features in \
audio_outputs["audio_features"]
for audio_feature in single_audio_features
]
audio_outputs["audio_feature_lens"] = torch.cat(
audio_outputs["audio_feature_lens"])
elif len(audio_embeds):
audio_outputs = {
"audio_lens": [
self.info.get_audio_len_by_num_chunks(
sum(chunk_embeds.shape[0]
for chunk_embeds in single_audio_embeds))
for single_audio_embeds in audio_embeds
],
"audio_embeds": [
chunk_embeds for single_audio_embeds in audio_embeds
for chunk_embeds in single_audio_embeds
],
"audio_num_segments": [
len(single_audio_embeds)
for single_audio_embeds in audio_embeds
]
}
else:
audio_outputs = {}
return audio_outputs
parsed_audios = (self._get_data_parser().parse_mm_data({
"audio": audios
}).get_items("audio", AudioProcessorItems))
audio_inputs = self._base_call_hf_processor(
prompts=[self.info.audio_pattern] * len(parsed_audios),
mm_data={"audios": [[audio] for audio in parsed_audios]},
mm_kwargs={
**mm_kwargs, "chunk_input": True
},
out_keys={"audio_features", "audio_feature_lens"},
)
# Avoid padding since we need the output for each audio to be
# independent of other audios for the cache to work correctly
unpadded_audio_features = [
feat[:, :feature_len] for feat, feature_len in zip(
audio_inputs["audio_features"],
audio_inputs["audio_feature_lens"],
)
]
audio_inputs["audio_features"] = unpadded_audio_features
return audio_inputs
def get_placeholder_match_pattern(self) -> str:
return r"\(<(image|video|audio)>./</\1>\)"
def get_placeholder_split_pattern(self) -> str:
return r"\(<(?:image|video|audio)>./</(?:image|video|audio)>\)"
def process_mm_inputs(
self,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> Mapping[str, Mapping[str, NestedTensors]]:
) -> Mapping[str, NestedTensors]:
return {
"image": self.process_images(mm_data, mm_kwargs),
"video": self.process_videos(mm_data, mm_kwargs),
"audio": self.process_audios(mm_data, mm_kwargs),
**super().process_mm_inputs(mm_data, mm_kwargs),
**self.process_audios(mm_data, mm_kwargs),
}
def get_modality_num_counter(self, modality: str) -> str:
if modality == "audio":
return "audio_lens"
return super().get_modality_num_counter(modality)
def get_num_slices_by_modality(self, inputs: Dict[str, object],
modality: str, index: int) -> int:
if modality == "audio":
return inputs["audio"]["audio_num_segments"][index]
return super().get_num_slices_by_modality(inputs, modality, index)
def get_prompt_texts_by_modality(self, inputs: Dict[str, object],
modality: str, index: int) -> str:
if modality == "audio":
return self.get_audio_prompt_texts(
inputs["audio"]["audio_lens"][index])
return super().get_prompt_texts_by_modality(inputs, modality, index)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
@ -622,86 +566,84 @@ class MiniCPMO(MiniCPMV2_6):
# Copied from HF repo of MiniCPM-o-2_6,
# designed for batched inputs and outputs
def get_audio_hidden_states(self, data: MiniCPMOAudioInputs,
chunk_length: int) -> torch.Tensor:
chunk_length: int) -> list[torch.Tensor]:
wavforms = data.get(
"data",
"audio_features",
[]) # (bs, 80, frames) or [], multi audios need filled in advance
audio_feature_lens_raw = [data.get("audio_feature_lens",
[])] # list, [[x1, x2], [y1], [z1]]
# exist audio
if len(wavforms) > 0:
audio_feature_lens = torch.hstack(audio_feature_lens_raw)
batch_size, _, max_mel_seq_len = wavforms.shape
max_seq_len = (max_mel_seq_len - 1) // 2 + 1
# Create a sequence tensor of shape (batch_size, max_seq_len)
seq_range = (torch.arange(
0,
max_seq_len,
dtype=audio_feature_lens.dtype,
device=audio_feature_lens.device).unsqueeze(0).expand(
batch_size, max_seq_len))
lengths_expand = audio_feature_lens.unsqueeze(1).expand(
batch_size, max_seq_len)
# Create mask
padding_mask = seq_range >= lengths_expand # 1 for padded values
audio_attention_mask_ = padding_mask.view(
batch_size, 1, 1, max_seq_len).expand(batch_size, 1,
max_seq_len, max_seq_len)
audio_attention_mask = audio_attention_mask_.to(
dtype=self.apm.conv1.weight.dtype,
device=self.apm.conv1.weight.device)
if chunk_length > 0:
chunk_num_frame = int(chunk_length * 50)
chunk_mask = self.subsequent_chunk_mask(
size=max_seq_len,
chunk_size=chunk_num_frame,
num_left_chunks=-1,
device=audio_attention_mask_.device,
)
audio_attention_mask_ = torch.logical_or(
audio_attention_mask_, torch.logical_not(chunk_mask))
audio_attention_mask[audio_attention_mask_] = float("-inf")
audio_states = self.apm(
wavforms, attention_mask=audio_attention_mask).hidden_states[
self.audio_encoder_layer]
audio_embeds = self.audio_projection_layer(audio_states)
audio_embeds = audio_embeds.transpose(1, 2)
audio_embeds = self.audio_avg_pooler(audio_embeds)
audio_embeds = audio_embeds.transpose(1, 2)
_, feature_lens_after_pooling = \
self._get_feat_extract_output_lengths(audio_feature_lens)
num_audio_tokens = feature_lens_after_pooling
final_audio_embeds = []
idx = 0
for i in range(len(audio_feature_lens_raw)):
target_audio_embeds = []
for _ in range(len(audio_feature_lens_raw[i])):
target_audio_embeds.append(
audio_embeds[idx, :num_audio_tokens[idx], :])
idx += 1
final_audio_embeds.append(target_audio_embeds)
return final_audio_embeds
else:
if len(wavforms) == 0:
return []
audio_feature_lens = torch.hstack(audio_feature_lens_raw)
batch_size, _, max_mel_seq_len = wavforms.shape
max_seq_len = (max_mel_seq_len - 1) // 2 + 1
# Create a sequence tensor of shape (batch_size, max_seq_len)
seq_range = (torch.arange(
0,
max_seq_len,
dtype=audio_feature_lens.dtype,
device=audio_feature_lens.device).unsqueeze(0).expand(
batch_size, max_seq_len))
lengths_expand = audio_feature_lens.unsqueeze(1).expand(
batch_size, max_seq_len)
# Create mask
padding_mask = seq_range >= lengths_expand # 1 for padded values
audio_attention_mask_ = padding_mask.view(
batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len,
max_seq_len)
audio_attention_mask = audio_attention_mask_.to(
dtype=self.apm.conv1.weight.dtype,
device=self.apm.conv1.weight.device)
if chunk_length > 0:
chunk_num_frame = int(chunk_length * 50)
chunk_mask = self.subsequent_chunk_mask(
size=max_seq_len,
chunk_size=chunk_num_frame,
num_left_chunks=-1,
device=audio_attention_mask_.device,
)
audio_attention_mask_ = torch.logical_or(
audio_attention_mask_, torch.logical_not(chunk_mask))
audio_attention_mask[audio_attention_mask_] = float("-inf")
audio_states = self.apm(
wavforms, attention_mask=audio_attention_mask).hidden_states[
self.audio_encoder_layer]
audio_embeds = self.audio_projection_layer(audio_states)
audio_embeds = audio_embeds.transpose(1, 2)
audio_embeds = self.audio_avg_pooler(audio_embeds)
audio_embeds = audio_embeds.transpose(1, 2)
_, feature_lens_after_pooling = \
self._get_feat_extract_output_lengths(audio_feature_lens)
num_audio_tokens = feature_lens_after_pooling
final_audio_embeds = []
idx = 0
for i in range(len(audio_feature_lens_raw)):
target_audio_embeds = []
for _ in range(len(audio_feature_lens_raw[i])):
target_audio_embeds.append(
audio_embeds[idx, :num_audio_tokens[idx], :])
idx += 1
final_audio_embeds.append(target_audio_embeds)
return final_audio_embeds
def get_embedding_with_audios(self, vlm_embedding: torch.Tensor,
audio_inputs: Optional[MiniCPMOAudioInputs],
audio_inputs: MiniCPMOAudioInputs,
chunk_length: int) -> torch.Tensor:
device, dtype = vlm_embedding.device, vlm_embedding.dtype
if audio_inputs["type"] == "audio_embeds":
audio_embeddings = audio_inputs["data"]
audio_embeddings = [
audio_embeddings[i].to(device=device, dtype=dtype)
for i in range(len(audio_embeddings))
item.to(device=device, dtype=dtype)
for item in audio_inputs["audio_embeds"]
]
else:
audio_embeddings = self.get_audio_hidden_states(
@ -746,40 +688,56 @@ class MiniCPMO(MiniCPMV2_6):
def _parse_and_validate_audio_inputs(
self, input_ids: torch.Tensor,
**kwargs: object) -> Tuple[MiniCPMOAudioInputs]:
audio_features = kwargs.pop("audio_features", [])
audio_feature_lens = kwargs.pop("audio_feature_lens", [])
**kwargs: object) -> Optional[MiniCPMOAudioInputs]:
audio_features = kwargs.pop("audio_features", None)
audio_embeds = kwargs.pop("audio_embeds", None)
audio_start_id = kwargs.pop("audio_start_id", None)
audio_end_id = kwargs.pop("audio_end_id", None)
if audio_features is None and audio_embeds is None:
return None
audio_start_id = kwargs.pop("audio_start_id")
if not isinstance(audio_start_id, torch.Tensor):
raise ValueError("Incorrect type of audio_start_id. "
f"Got type: {type(audio_start_id)}")
audio_end_id = kwargs.pop("audio_end_id")
if not isinstance(audio_end_id, torch.Tensor):
raise ValueError("Incorrect type of audio_end_id. "
f"Got type: {type(audio_end_id)}")
if audio_embeds is not None:
audio_embeds = [
audio_embeds[i][j] for i in range(len(audio_embeds))
for j in range(len(audio_embeds[i]))
]
if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embeds. "
f"Got type: {type(audio_embeds)}")
return MiniCPMOAudioEmbeddingInputs(
type="audio_embeds",
audio_embeds=flatten_bn(flatten_2d_lists(audio_embeds),
concat=True),
audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
audio_end_id),
data=audio_embeds,
type="audio_embeds")
if len(audio_features) > 0:
audio_features_all = [
i.permute(1, 0) for audio_feature in audio_features
for i in audio_feature
]
audio_features = torch.nn.utils.rnn.pad_sequence(
audio_features_all, batch_first=True,
padding_value=0.0).permute(0, 2, 1)
audio_feature_lens = torch.cat(
[item for item in audio_feature_lens])
)
if audio_features is not None:
if not isinstance(audio_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_features. "
f"Got type: {type(audio_features)}")
audio_feature_lens = kwargs.pop("audio_feature_lens")
if not isinstance(audio_feature_lens, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_feature_lens. "
f"Got type: {type(audio_feature_lens)}")
return MiniCPMOAudioFeatureInputs(
type="audio_features",
audio_features=flatten_bn(audio_features, concat=True),
audio_feature_lens=flatten_bn(
flatten_2d_lists(audio_feature_lens), concat=True),
audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
audio_end_id),
data=audio_features,
audio_feature_lens=audio_feature_lens,
type="audio_features")
return None
)
raise AssertionError("This line should be unreachable.")
def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
**kwargs: object):
@ -803,7 +761,7 @@ class MiniCPMO(MiniCPMV2_6):
else:
image_inputs, audio_inputs = \
self._parse_and_validate_inputs(input_ids, **kwargs)
vlm_embeddings, _ = self.get_embedding_with_vision(
vlm_embeddings = self.get_embedding_with_vision(
input_ids, image_inputs)
if audio_inputs is not None:

View File

@ -24,6 +24,7 @@
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math
import re
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial
from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
@ -63,11 +64,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
SupportsV0Only)
from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
CPU_DEVICE = torch.device("cpu")
@ -76,7 +78,7 @@ RawImageType = Union[Image.Image, torch.Tensor]
class MiniCPMVImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: List[torch.Tensor]
pixel_values: list[torch.Tensor]
"""
Shape: `(batch_size * num_images * num_slices, num_channels, height, width)`
@ -101,7 +103,7 @@ class MiniCPMVImagePixelInputs(TypedDict):
class MiniCPMVImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
image_embeds: torch.Tensor
"""
Shape: `(batch_size * num_images * num_slices,
image_feature_size, hidden_size)`
@ -231,26 +233,15 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0))
video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0))
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"),
tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
image_num_slices=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
video_pixel_values=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
tgt_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
video_pixel_values=MultiModalFieldConfig.batched("video"),
video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_num_slices=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.batched("video"),
video_embeds=MultiModalFieldConfig.batched("video"),
)
@ -356,12 +347,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
def get_model_version(self):
return get_version_by_config(self.get_hf_config())
def get_supported_mm_modalities(self) -> List[str]:
if self.get_model_version() == (2, 6):
return ["image", "video"]
else:
return ["image"]
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
if self.get_model_version() == (2, 6):
return {"image": None, "video": None}
@ -526,187 +511,123 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
def get_image_prompt_texts(self,
image_size: ImageSize,
image_idx: int = 0) -> str:
prompt_texts = self.get_slice_image_placeholder(image_size,
image_idx=image_idx)
return prompt_texts
return self.get_slice_image_placeholder(image_size,
image_idx=image_idx)
def get_video_prompt_texts(self, image_size: ImageSize,
num_frames: int) -> str:
prompt_texts = "".join(
self.get_slice_image_placeholder(
image_size=image_size,
image_idx=0,
max_slice_nums=self.info.get_video_max_slice_num(),
use_image_id=False) for image_idx in range(num_frames))
return prompt_texts
return self.get_slice_image_placeholder(
image_size=image_size,
image_idx=0,
max_slice_nums=self.info.get_video_max_slice_num(),
use_image_id=False,
) * num_frames
def get_special_tokens(self) -> Dict[str, torch.Tensor]:
tokenizer = self.info.get_tokenizer()
special_tokens = {
"im_start_id": torch.tensor(tokenizer.im_start_id),
"im_end_id": torch.tensor(tokenizer.im_end_id)
"im_start_id": tokenizer.im_start_id,
"im_end_id": tokenizer.im_end_id,
}
if hasattr(tokenizer, "slice_start_id"):
special_tokens["slice_start_id"] = torch.tensor(
tokenizer.slice_start_id)
special_tokens["slice_end_id"] = torch.tensor(
tokenizer.slice_end_id)
return special_tokens
special_tokens["slice_start_id"] = tokenizer.slice_start_id
special_tokens["slice_end_id"] = tokenizer.slice_end_id
@staticmethod
def repack_processor_outputs(outputs: Any) -> BatchFeature:
valid_keys = ["pixel_values", "image_sizes", "tgt_sizes"]
outputs = {key: outputs[key][0] for key in valid_keys}
return outputs
return {k: torch.tensor(v) for k, v in special_tokens.items()}
def process_images(
self,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
mm_data = dict(mm_data)
if (images := mm_data.get("images")) is None:
return {}
images = mm_data.pop("images", [])
image_embeds = mm_data.pop("image_embeds", [])
if isinstance(images, Image.Image):
images = [images]
if isinstance(images, (list, torch.Tensor)) and len(images) > 0:
image_outputs = super()._call_hf_processor(
prompt=self.info.image_pattern * len(images),
mm_data={"images": images},
mm_kwargs=mm_kwargs)
image_outputs = self.repack_processor_outputs(image_outputs)
elif len(image_embeds) > 0:
image_sizes = mm_data.pop("image_sizes", None)
image_outputs = {
"image_embeds": torch.cat(image_embeds),
"image_sizes": image_sizes
}
else:
image_outputs = {}
return image_outputs
parsed_images = (self._get_data_parser().parse_mm_data({
"image": images
}).get_items("image", ImageProcessorItems))
return self._base_call_hf_processor(
prompts=[self.info.image_pattern] * len(parsed_images),
mm_data={"images": [[image] for image in parsed_images]},
mm_kwargs=mm_kwargs,
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
)
def process_videos(
self,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
mm_data = dict(mm_data)
if (videos := mm_data.get("videos")) is None:
return {}
videos = mm_data.pop("videos", [])
video_embeds = mm_data.pop("video_embeds", [])
if len(videos) > 0 and isinstance(videos[0], Image.Image):
videos = [videos]
if isinstance(videos, list) and len(videos) > 0:
video_outputs = {
"video_pixel_values": [],
"video_image_sizes": [],
"video_tgt_sizes": [],
"num_frames": []
}
for video in videos:
parsed_video = []
for frame in video:
if isinstance(frame, np.ndarray):
parsed_video.append(Image.fromarray(frame))
else:
parsed_video.append(frame)
video = parsed_video
single_video_outputs = super()._call_hf_processor(
prompt=self.info.image_pattern * len(video),
mm_data={"images": video},
mm_kwargs={
**mm_kwargs, "max_slice_nums":
self.info.get_video_max_slice_num()
})
video_outputs["num_frames"].append(len(video))
for key in single_video_outputs:
if "video_" + key in video_outputs:
if key == "image_sizes":
video_outputs["video_" + key].append(
single_video_outputs[key][0][0])
else:
video_outputs["video_" +
key] += single_video_outputs[key][0]
elif len(video_embeds):
image_sizes = mm_data.pop("image_sizes", None)
num_frames = mm_data.pop("num_frames", None)
video_outputs = {
"video_embeds": torch.cat(video_embeds),
"video_image_sizes": image_sizes,
"num_frames": num_frames
}
else:
video_outputs = {}
return video_outputs
parsed_videos = (self._get_data_parser().parse_mm_data({
"video": videos
}).get_items("video", VideoProcessorItems))
max_slice_num = self.info.get_video_max_slice_num()
video_inputs = self._base_call_hf_processor(
prompts=[
self.info.image_pattern * len(video) for video in parsed_videos
],
mm_data={"images": list(parsed_videos)},
mm_kwargs={
**mm_kwargs, "max_slice_nums": max_slice_num
},
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
)
return {f"video_{k}": v for k, v in video_inputs.items()}
def get_placeholder_match_pattern(self) -> str:
return r"\(<(image|video)>./</\1>\)"
def get_placeholder_split_pattern(self) -> str:
return r"\(<(?:image|video)>./</(?:image|video)>\)"
def process_mm_inputs(
self,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> Mapping[str, Mapping[str, NestedTensors]]:
) -> Mapping[str, NestedTensors]:
return {
"image": self.process_images(mm_data, mm_kwargs),
"video": self.process_videos(mm_data, mm_kwargs),
**self.process_images(mm_data, mm_kwargs),
**self.process_videos(mm_data, mm_kwargs),
}
def get_input_modalities(self, mm_data) -> List[str]:
supported_mm_modalities = self.info.get_supported_mm_modalities()
input_modalities = []
for modality in supported_mm_modalities:
if modality in mm_data and mm_data[modality] != {}:
input_modalities.append(modality)
return input_modalities
def get_modality_num_counter(self, modality: str) -> str:
if modality == "image":
return "image_sizes"
elif modality == "video":
return "video_image_sizes"
raise NotImplementedError(modality)
def get_num_slices_by_modality(self, inputs: dict[str, Any], modality: str,
index: int) -> int:
if modality == "image":
return self.info.get_image_slice_nums(
inputs[modality]["image_sizes"][index],
self.info.get_max_slice_num())
elif modality == "video":
return self.info.get_image_slice_nums(
inputs[modality]["video_image_sizes"][index],
self.info.get_video_max_slice_num()
) * inputs[modality]["num_frames"][index]
else:
raise ValueError(f"Unexpected modality: {modality}")
def get_prompt_texts_by_modality(self, inputs: dict[str, Any],
modality: str, index: int) -> str:
if modality == "image":
return self.get_image_prompt_texts(
inputs["image"]["image_sizes"][index], index)
elif modality == "video":
return self.get_video_prompt_texts(
inputs["video"]["video_image_sizes"][index],
inputs["video"]["num_frames"][index])
else:
raise ValueError(f"Unexpected modality: {modality}")
def call_base_hf_processor(
def _base_call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
prompts: list[str],
mm_data: Mapping[str, Sequence[object]],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
return super()._call_hf_processor(prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs)
*,
out_keys: set[str],
) -> Mapping[str, NestedTensors]:
# This processor supports zipping prompt and mm_data together
if self.info.get_model_version() == (2, 6):
inputs = super()._call_hf_processor(
prompt=prompts, # type: ignore
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
else:
inputs = defaultdict[str, list[torch.Tensor]](list)
for i, prompt in enumerate(prompts):
inputs_one = super()._call_hf_processor(
prompt=prompt,
mm_data={
k: v[i]
for k, v in mm_data.items()
},
mm_kwargs=mm_kwargs,
)
for k, v in inputs_one.items():
assert len(v) == 1, (k, len(v))
inputs[k].append(v[0])
return {k: inputs[k] for k in out_keys}
def _call_hf_processor(
self,
@ -717,35 +638,12 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
# Do not support combination inputs of images and videos for now
# Try to handle interleaved multimodal data
tokenizer = self.info.get_tokenizer()
inputs = self.process_mm_inputs(mm_data, mm_kwargs)
mm_input_modalities = self.get_input_modalities(inputs)
num_mm_slices_lst = {
modality: list[int]()
for modality in mm_input_modalities
}
for modality in mm_input_modalities:
num_counter_key = self.get_modality_num_counter(modality)
for index in range(len(inputs[modality][num_counter_key])):
num_mm_slices_lst[modality].append(
self.get_num_slices_by_modality(inputs, modality, index))
num_mm_slices = {
modality: torch.tensor(v)
for modality, v in num_mm_slices_lst.items()
}
mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs)
return BatchFeature({
"input_ids": np.array([tokenizer.encode(prompt)]),
**{
key: value
for modality in inputs
for key, value in inputs[modality].items()
},
**{
f"{modality}_num_slices": num_mm_slices[modality]
for modality in mm_input_modalities
}
"input_ids":
torch.tensor([tokenizer.encode(prompt)]),
**mm_inputs,
})
def _hf_processor_applies_updates(
@ -810,7 +708,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
hf_processor_mm_kwargs: Mapping[str, object],
return_mm_hashes: bool = False,
) -> MultiModalInputs:
supported_mm_modalities = self.info.get_supported_mm_modalities()
if isinstance(prompt, list):
prompt = self.info.get_tokenizer().decode(prompt)
matches = re.findall(self.get_placeholder_match_pattern(), prompt)
@ -818,7 +715,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
f"{modality}_orders":
torch.tensor(
[index for index, m in enumerate(matches) if m == modality])
for modality in supported_mm_modalities
for modality in self.info.get_supported_mm_limits()
}
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
return_mm_hashes)
@ -884,35 +781,35 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
self,
input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImageInputs],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
if image_inputs is None: # No image
vision_hidden_states = torch.tensor([], device=input_ids.device)
if image_inputs is None:
return vlm_embedding
if image_inputs["type"] == "image_embeds":
vision_hidden_states = image_inputs["image_embeds"].to(
device=vlm_embedding.device,
dtype=vlm_embedding.dtype,
)
else:
if image_inputs["type"] == "image_embeds":
vision_hidden_states = (image_inputs["data"].type(
vlm_embedding.dtype).to(vlm_embedding.device))
else:
vision_hidden_states = self.get_vision_hidden_states(
image_inputs)
vision_hidden_states = self.get_vision_hidden_states(image_inputs)
# See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"]
if len(image_bounds) > 0:
image_indices = torch.stack([
torch.arange(start, end, dtype=torch.long)
for start, end in image_bounds.tolist()
]).to(vlm_embedding.device)
vlm_embedding.scatter_(
0,
image_indices.view(-1, 1).repeat(1,
vlm_embedding.shape[-1]),
vision_hidden_states.view(-1,
vision_hidden_states.shape[-1]),
)
# See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"]
if len(image_bounds) > 0:
image_indices = torch.stack([
torch.arange(start, end, dtype=torch.long)
for start, end in image_bounds.tolist()
]).to(vlm_embedding.device)
return vlm_embedding, vision_hidden_states
vlm_embedding.scatter_(
0,
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
)
return vlm_embedding
def _get_image_bounds(
self,
@ -947,90 +844,115 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
input_ids: torch.Tensor,
**kwargs: object,
) -> Optional[MiniCPMVImageInputs]:
mm_data = {
image_keys = {"pixel_values", "tgt_sizes"}
pixel_data = {
"image": {
key: kwargs.pop(key, [])
for key in ["pixel_values", "tgt_sizes", "image_num_slices"]
key: kwargs.pop(key, None)
for key in image_keys
},
"video": {
"pixel_values": kwargs.pop("video_pixel_values", []),
"tgt_sizes": kwargs.pop("video_tgt_sizes", []),
"video_num_slices": kwargs.pop("video_num_slices", [])
key: kwargs.pop("video_" + key, None)
for key in image_keys
}
}
im_start_id = kwargs.pop("im_start_id", None)
im_end_id = kwargs.pop("im_end_id", None)
slice_start_id = kwargs.pop("slice_start_id", None)
slice_end_id = kwargs.pop("slice_end_id", None)
mm_orders = {
f"{modality}": kwargs.pop(f"{modality}_orders", None)
for modality in ["image", "video", "audio"]
embed_data = {
"image": kwargs.pop("image_embeds", None),
"video": kwargs.pop("video_embeds", None),
}
batch_size = max(len(mm_data["image"]["pixel_values"]),
len(mm_data["video"]["pixel_values"]))
image_embeds = kwargs.pop("image_embeds", None)
video_embeds = kwargs.pop("video_embeds", None)
if image_embeds is not None and video_embeds is not None:
raise ValueError(
"Incorrect inputs for vision embeddings. "
"Image embeds and video embeds can not exist simultaneously.")
if video_embeds is not None:
image_embeds = video_embeds
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of image embeds. "
f"Got type: {type(image_embeds)}")
image_embeds = torch.concat(
[image_embeds[i] for i in range(len(image_embeds))])
all_pixel_data = [
v for vs in pixel_data.values() for v in vs.values()
if v is not None
]
all_embed_data = [v for v in embed_data.values() if v is not None]
if len(all_pixel_data) == 0 and len(all_embed_data) == 0:
return None
im_start_id = kwargs.pop("im_start_id")
if not isinstance(im_start_id, torch.Tensor):
raise ValueError("Incorrect type of im_start_id. "
f"Got type: {type(im_start_id)}")
im_end_id = kwargs.pop("im_end_id")
if not isinstance(im_end_id, torch.Tensor):
raise ValueError("Incorrect type of im_end_id. "
f"Got type: {type(im_end_id)}")
slice_start_id = kwargs.pop("slice_start_id", None)
if slice_start_id is not None and not isinstance(
slice_start_id, torch.Tensor):
raise ValueError("Incorrect type of slice_start_id. "
f"Got type: {type(slice_start_id)}")
slice_end_id = kwargs.pop("slice_end_id", None)
if slice_end_id is not None and not isinstance(slice_end_id,
torch.Tensor):
raise ValueError("Incorrect type of slice_end_id. "
f"Got type: {type(slice_end_id)}")
if len(all_embed_data) > 0:
if len(all_embed_data) > 1:
raise ValueError("Incorrect inputs for vision embeddings. "
"Image embeds and video embeds can not "
"exist simultaneously.")
vision_embeds, = all_embed_data
if not isinstance(vision_embeds, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of vision_embeds. "
f"Got type: {type(vision_embeds)}")
return MiniCPMVImageEmbeddingInputs(
type="image_embeds",
image_embeds=flatten_bn(flatten_2d_lists(vision_embeds),
concat=True),
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
data=image_embeds,
type="image_embeds",
)
for modality, modality_mm_data in mm_data.items():
if not isinstance(modality_mm_data["pixel_values"],
(torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel values. "
f"Got type: {type(modality_mm_data['pixel_values'])}")
if not isinstance(modality_mm_data["tgt_sizes"],
(torch.Tensor, list)):
raise ValueError(
"Incorrect type of target sizes. "
f"Got type: {type(modality_mm_data['tgt_sizes'])}")
order_data = dict[str, Union[torch.Tensor, list[torch.Tensor]]]()
for modality in ("image", "video"):
modality_orders = kwargs.pop(f"{modality}_orders", None)
if modality_orders is not None:
if not isinstance(modality_orders, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {modality}_orders. "
f"Got type: {type(modality_orders)}")
if len(modality_mm_data["pixel_values"]) != len(
modality_mm_data["tgt_sizes"]):
raise ValueError(
"Inconsistent batch lengths, found: "
f"{len(modality_mm_data['pixel_values'])} vs. "
f"{len(modality_mm_data['tgt_sizes'])}")
order_data[modality] = modality_orders
pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
batch_sizes = {
modality: len(modality_orders)
for modality, modality_orders in order_data.items()
}
unique_batch_sizes = set(batch_sizes.values())
assert len(unique_batch_sizes) == 1, (
f"Found inconsistent batch sizes: {batch_sizes}")
batch_size, = unique_batch_sizes
pixel_values_flat = list[torch.Tensor]()
tgt_sizes_flat = list[torch.Tensor]()
for b in range(batch_size):
mm_counts = {"image": 0, "video": 0} if self.version == (2, 6) \
else {"image": 0}
mm_slice_counts = {"image": 0, "video": 0} \
if self.version == (2, 6) else {"image": 0}
mm_orders_b = [(index, modality) for modality in mm_counts
for index in mm_orders[modality][b]]
mm_orders_b = [(idx_b.item(), modality)
for modality, modality_orders in order_data.items()
for idx_b in modality_orders[b]]
for _, modality in sorted(mm_orders_b, key=lambda x: x[0]):
pos = mm_counts[modality]
num_slices = mm_data[modality][f"{modality}_num_slices"][b][
pos]
slice_start_idx = mm_slice_counts[modality]
slice_end_idx = slice_start_idx + num_slices
pixel_values_flat += mm_data[modality]["pixel_values"][b][
slice_start_idx:slice_end_idx]
tgt_sizes_flat += mm_data[modality]["tgt_sizes"][b][
slice_start_idx:slice_end_idx]
mm_counts[modality] += 1
mm_slice_counts[modality] += num_slices
modality_pixel_data = pixel_data[modality]
modality_pixel_values = modality_pixel_data["pixel_values"]
if not isinstance(modality_pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel_values for {modality=}. "
f"Got type: {type(modality_pixel_values)}")
modality_tgt_sizes = modality_pixel_data["tgt_sizes"]
if not isinstance(modality_tgt_sizes, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of tgt_sizes for {modality=}. "
f"Got type: {type(modality_tgt_sizes)}")
pixel_values_flat += flatten_2d_lists(modality_pixel_values[b])
tgt_sizes_flat += flatten_2d_lists(modality_tgt_sizes[b])
# NOTE: Input IDs does not contain image tokens during memory profiling,
# so we allow it to be empty
@ -1042,16 +964,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
if len(pixel_values_flat) == 0:
return None
if im_start_id is None:
return None
return MiniCPMVImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values_flat,
tgt_sizes=torch.stack(tgt_sizes_flat),
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
data=pixel_values_flat,
tgt_sizes=torch.stack(tgt_sizes_flat),
type="pixel_values",
)
def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
@ -1070,7 +989,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
else:
image_inputs = \
self._parse_and_validate_inputs(input_ids, **kwargs)
vlm_embeddings, _ = self.get_embedding_with_vision(
vlm_embeddings = self.get_embedding_with_vision(
input_ids, image_inputs)
# always pass the input via `inputs_embeds`
@ -1136,16 +1055,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
prefix: str = "") -> nn.Module:
raise NotImplementedError
def get_vision_embedding(
self,
pixel_values: List[torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
raise NotImplementedError
@ -1216,35 +1127,27 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
return resampler.to(device=current_platform.device_type,
dtype=torch.get_default_dtype())
def get_vision_embedding(
self,
pixel_values: List[torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
res = []
dtype = self.vpm.pos_embed.data.dtype
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
P_h, P_w = self.vpm.patch_embed.patch_size
dtype: torch.dtype = self.vpm.pos_embed.data.dtype
num_prefix_tokens = getattr(self.vpm, "num_prefix_tokens", 0)
res = list[torch.Tensor]()
for pixel_value in pixel_values:
H, W = pixel_value[0].shape[-2:]
tgt_size = (
math.ceil(H / self.vpm.patch_embed.patch_size[0]),
math.ceil(W / self.vpm.patch_embed.patch_size[0]),
)
tgt_size = (math.ceil(H / P_h), math.ceil(W / P_w))
vision_embedding = self.vpm.forward_features(
pixel_value.unsqueeze(0).type(dtype))
if (hasattr(self.vpm, "num_prefix_tokens")
and self.vpm.num_prefix_tokens > 0):
vision_embedding = vision_embedding[:, self.vpm.
num_prefix_tokens:]
if num_prefix_tokens > 0:
vision_embedding = vision_embedding[:, num_prefix_tokens:]
res.append(self.resampler(vision_embedding, tgt_size))
return torch.vstack(res)
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["data"]
return self.get_vision_embedding(pixel_values)
class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
packed_modules_mapping = {
@ -1299,45 +1202,41 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
return resampler.to(device=current_platform.device_type,
dtype=torch.get_default_dtype())
def get_vision_embedding(
self,
pixel_values: List[torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
vision_embedding = self.vpm(pixel_values,
patch_attention_mask=patch_attn_mask)
vision_embedding = self.resampler(vision_embedding, tgt_sizes)
return vision_embedding
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["data"]
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
tgt_sizes = data["tgt_sizes"]
device = self.vpm.embeddings.position_embedding.weight.device
dtype = self.vpm.embeddings.position_embedding.weight.dtype
all_pixel_values_lst = [
i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
]
B = len(pixel_values)
P = pixel_values[0].shape[-2]
L = max(item.shape[-1] for item in pixel_values)
device = pixel_values[0].device
dtype = pixel_values[0].dtype
max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
all_pixel_values = torch.zeros((B, 3, P, L),
dtype=dtype,
device=device)
for i, pixel_values_item in enumerate(pixel_values):
L_item = pixel_values_item.shape[-1]
all_pixel_values[i, ..., :L_item] = pixel_values_item
num_patches = tgt_sizes.prod(-1)
max_patches = num_patches.max().item()
assert isinstance(max_patches, int)
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
all_pixel_values_lst, batch_first=True, padding_value=0.0)
B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(0, 2,
1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros((B, 1, max_patches),
patch_attn_mask = torch.zeros((B, max_patches),
dtype=torch.bool,
device=device)
for i in range(B):
patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
for i, num_patches_item in enumerate(num_patches):
patch_attn_mask[i, :num_patches_item] = True
return self.get_vision_embedding(all_pixel_values.type(dtype),
patch_attn_mask, tgt_sizes)
vision_embedding = self.vpm(
all_pixel_values,
patch_attention_mask=patch_attn_mask.unsqueeze(1),
tgt_sizes=None,
)
return self.resampler(vision_embedding, tgt_sizes)
class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
@ -1394,47 +1293,37 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
return resampler.to(device=current_platform.device_type,
dtype=torch.get_default_dtype())
def get_vision_embedding(
self,
pixel_values: List[torch.Tensor],
patch_attn_mask: Optional[torch.Tensor] = None,
tgt_sizes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
vision_embedding = self.vpm(
pixel_values,
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes,
)
return vision_embedding
def get_vision_hidden_states(self,
data: MiniCPMVImageInputs) -> torch.Tensor:
pixel_values = data["data"]
def get_vision_hidden_states(
self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
pixel_values = data["pixel_values"]
tgt_sizes = data["tgt_sizes"]
device = self.vpm.embeddings.position_embedding.weight.device
dtype = self.vpm.embeddings.position_embedding.weight.dtype
all_pixel_values_lst = [
i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
]
B = len(pixel_values)
P = pixel_values[0].shape[-2]
L = max(item.shape[-1] for item in pixel_values)
device = pixel_values[0].device
dtype = pixel_values[0].dtype
max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
all_pixel_values = torch.zeros((B, 3, P, L),
dtype=dtype,
device=device)
for i, pixel_values_item in enumerate(pixel_values):
L_item = pixel_values_item.shape[-1]
all_pixel_values[i, ..., :L_item] = pixel_values_item
num_patches = tgt_sizes.prod(-1)
max_patches = num_patches.max().item()
assert isinstance(max_patches, int)
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
all_pixel_values_lst, batch_first=True, padding_value=0.0)
B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(0, 2,
1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros((B, 1, max_patches),
patch_attn_mask = torch.zeros((B, max_patches),
dtype=torch.bool,
device=device)
for i in range(B):
patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
for i, num_patches_item in enumerate(num_patches):
patch_attn_mask[i, :num_patches_item] = True
vision_embedding = self.vpm(
all_pixel_values.type(dtype),
patch_attention_mask=patch_attn_mask,
all_pixel_values,
patch_attention_mask=patch_attn_mask.unsqueeze(1),
tgt_sizes=tgt_sizes,
)

View File

@ -77,9 +77,6 @@ class PixtralImagePixelInputs(TypedDict):
Shape: `(batch_size, num_images, num_embeds)`
"""
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
class PixtralProcessorAdapter:
"""
@ -153,7 +150,6 @@ class PixtralProcessorAdapter:
images_processed = list[torch.Tensor]()
images_tokens = list[torch.Tensor]()
images_embed_is_patch = list[torch.Tensor]()
images_num_embeds = list[int]()
for image in images:
image_inputs = self.image_processor(ImageChunk(image=image))
@ -163,13 +159,11 @@ class PixtralProcessorAdapter:
images_processed.append(image_processed)
images_tokens.append(image_tokens)
images_embed_is_patch.append(image_tokens == image_token_id)
images_num_embeds.append(len(image_tokens))
return {
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
"images": images_processed,
"embed_is_patch": images_embed_is_patch,
"num_embeds": torch.tensor(images_num_embeds),
}
@ -273,7 +267,6 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
return dict(
images=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
@ -394,16 +387,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
num_embeds = kwargs.pop("num_embeds")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
return PixtralImagePixelInputs(
type="pixel_values",
images=flatten_bn(images),
embed_is_patch=embed_is_patch,
num_embeds=num_embeds,
)
def _process_image_input(
@ -447,7 +434,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
image_features,
image_input["num_embeds"],
image_input["embed_is_patch"],
))

View File

@ -155,7 +155,6 @@ def resolve_visual_encoder_outputs(
def scatter_patch_features(
features: torch.Tensor,
num_embeds: torch.Tensor,
embed_is_patch: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
"""
@ -168,13 +167,35 @@ def scatter_patch_features(
Args:
features: The patch features, concatenated across each image.
Shape: `(num_patch, feature_depth)`
num_embeds: The number of image embeddings for each image.
Shape: `(num_images,)`
embed_is_patch: A boolean mask indicating which image embeddings
correspond to patch tokens for each image.
Shape: `(num_images, num_embeds)`
Note:
The original code only considers patch tokens as feature
tokens, but our processor considers all image-related tokens
as feature tokens because the feature tokens need to be
consecutive in `input_ids`.
Example:
A simplified example for one image:
.. code-block::
Embedding tokens (from HF processor):
[<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
embed_is_patch (from HF processor):
[ False True True False True True False False ]
Encoder outputs (from model):
[ p1 p2 p3 p4 ]
The resulting embedding tensor is:
[ nan p1 p2 nan p3 p4 nan nan ]
"""
num_embeds_per_image: list[int] = num_embeds.tolist()
num_images, num_embeds = embed_is_patch.shape
num_embeds_per_image = [num_embeds] * num_images
embeds_flat = features.new_full(
(sum(num_embeds_per_image), features.shape[-1]),

View File

@ -665,6 +665,13 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
return cast(BatchedTensorInputs, json_mapped)
def __delitem__(self, key: str) -> None:
super().__delitem__(key)
for items in self._items_by_modality.values():
for item in items:
item.pop(key, None)
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False

View File

@ -37,6 +37,9 @@ class CpuPlatform(Platform):
use_mla: bool) -> str:
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:
logger.info("Using CPU MLA backend.")
return "vllm.attention.backends.cpu_mla.CPUMLABackend"
logger.info("Using Torch SDPA backend.")
return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
@ -129,9 +132,6 @@ class CpuPlatform(Platform):
# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
# MLA attention is not supported
os.environ["VLLM_MLA_DISABLE"] = "1"
# Intel OpenMP setting
ld_prealod_str = os.getenv("LD_PRELOAD", "")
if "libiomp5.so" in ld_prealod_str:

View File

@ -9,6 +9,7 @@ logger = init_logger(__name__)
class BlockTable:
"""Device-agnostic block table for storing block IDs for each request."""
def __init__(
self,

View File

@ -0,0 +1,164 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Set
import numpy as np
import torch
from vllm import _custom_ops as ops
from vllm.logger import init_logger
logger = init_logger(__name__)
class GPUBlockTable:
def __init__(
self,
max_num_reqs: int,
max_num_blocks_per_req: int,
pin_memory: bool,
device: torch.device,
):
self.max_num_reqs = max_num_reqs
self.max_num_blocks_per_req = max_num_blocks_per_req
self.pin_memory = pin_memory
self.device = device
self.block_table = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device=self.device,
dtype=torch.int32,
)
self.block_table_cpu = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device="cpu",
dtype=torch.int32,
pin_memory=False,
)
self.block_table_np = self.block_table_cpu.numpy()
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
self.block_table_diff_np = np.zeros(
(max_num_reqs, 2),
dtype=np.int32,
)
self.diff_rows: Set[int] = set()
self.append_row_indices = torch.zeros(
(max_num_reqs, 2),
dtype=torch.int32,
device=self.device,
)
self.append_row_indices_cpu = torch.zeros_like(
self.append_row_indices,
device="cpu",
pin_memory=pin_memory,
)
self.append_row_indices_np = self.append_row_indices_cpu.numpy()
self.append_cumsums = torch.zeros(
(max_num_reqs + 1, ),
dtype=torch.int32,
device=self.device,
)
self.append_cumsums_cpu = torch.zeros_like(
self.append_cumsums,
device="cpu",
pin_memory=pin_memory,
)
self.append_cumsums_np = self.append_cumsums_cpu.numpy()
self.append_data = torch.zeros(
(max_num_reqs * max_num_blocks_per_req, ),
dtype=torch.int32,
device=self.device,
)
self.append_data_cpu = torch.zeros_like(
self.append_data,
device="cpu",
pin_memory=pin_memory,
)
self.append_data_np = self.append_data_cpu.numpy()
def append_row(
self,
row_idx: int,
start: int,
block_ids: List[int],
) -> None:
num_blocks = len(block_ids)
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
self.num_blocks_per_row[row_idx] = start + num_blocks
self.block_table_diff_np[row_idx, 0] = start
self.block_table_diff_np[row_idx, 1] = num_blocks
self.diff_rows.add(row_idx)
def add_row(self, row_idx: int, block_ids: List[int]) -> None:
self.append_row(row_idx, 0, block_ids)
def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src]
self.block_table_np[tgt, :num_blocks] = self.block_table_np[
src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks
self.block_table_diff_np[tgt, 0] = 0
self.block_table_diff_np[tgt, 1] = num_blocks
self.diff_rows.discard(src)
self.diff_rows.add(tgt)
def commit(self, num_reqs: int) -> None:
if not self.diff_rows:
return
cu_end = 0
self.append_cumsums_np[0] = 0
for i, row_idx in enumerate(self.diff_rows):
start, num_blocks = self.block_table_diff_np[row_idx]
assert num_blocks > 0
self.append_row_indices_np[i, 0] = row_idx
self.append_row_indices_np[i, 1] = start
cu_start = self.append_cumsums_np[i]
cu_end = cu_start + num_blocks
self.append_cumsums_np[i + 1] = cu_end
self.append_data_np[cu_start:cu_end] = self.block_table_np[
row_idx, start:start + num_blocks]
ops.block_table_appends(
self.append_row_indices,
self.append_row_indices_cpu,
self.append_cumsums,
self.append_cumsums_cpu,
self.append_data,
self.append_data_cpu,
self.block_table,
len(self.diff_rows),
cu_end,
)
self.diff_rows.clear()
def clear(self) -> None:
self.block_table.fill_(0)
self.block_table_cpu.fill_(0)
self.diff_rows.clear()
self.block_table_diff_np.fill(0)
self.append_row_indices.fill_(0)
self.append_row_indices_cpu.fill_(0)
self.append_cumsums.fill_(0)
self.append_cumsums_cpu.fill_(0)
self.append_data.fill_(0)
self.append_data_cpu.fill_(0)
def get_device_tensor(self) -> torch.Tensor:
"""Ruturns the device tensor of the block table."""
return self.block_table
def get_cpu_tensor(self) -> torch.Tensor:
"""Returns the CPU tensor of the block table."""
return self.block_table_cpu
def get_numpy_array(self) -> np.ndarray:
"""Returns the numpy array of the block table."""
return self.block_table_np

View File

@ -14,7 +14,7 @@ from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_block_table import GPUBlockTable
_SAMPLING_EPS = 1e-5
@ -92,7 +92,7 @@ class InputBatch:
self.num_computed_tokens_cpu_tensor.numpy()
# Block table.
self.block_table = BlockTable(
self.block_table = GPUBlockTable(
max_num_reqs=max_num_reqs,
max_num_blocks_per_req=max_num_blocks_per_req,
pin_memory=pin_memory,

View File

@ -66,14 +66,18 @@ class TPUWorker:
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
# Delay profiler initialization to the start of the profiling.
# This is because in vLLM V1, MP runtime is initialized before the
# TPU Worker is initialized. The profiler server needs to start after
# MP runtime is initialized.
self.profiler = None
self.profile_dir = None
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
# For TPU, we can only have 1 active profiler session for 1 profiler
# server. So we only profile on rank0.
self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
self.profile_dir)
self.profiler = xp.start_server(9012)
if self.model_config.seed is None:
self.model_config.seed = 0
@ -168,9 +172,11 @@ class TPUWorker:
def profile(self, is_start: bool = True):
if self.rank < 1:
if self.profiler is None:
if self.profile_dir is None:
raise RuntimeError("Profiler is not enabled.")
if is_start:
if self.profiler is None:
self.profiler = xp.start_server(9012)
xp.start_trace(self.profile_dir)
else:
xp.stop_trace()

View File

@ -469,6 +469,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
) if needs_attn_backend else None
# Multi-modal data support

View File

@ -66,6 +66,7 @@ class CPUCacheEngine:
cache_config.cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
)
# Initialize the cache.