mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-24 17:04:35 +08:00
Compare commits
52 Commits
Author | SHA1 | Date | |
---|---|---|---|
852ef5b4f5 | |||
db09d4ad83 | |||
c957c741d9 | |||
c07ece5ca4 | |||
7a9c20c715 | |||
005ba458b5 | |||
320a622ec4 | |||
c9927c1a6a | |||
fbd80ad409 | |||
22379d5513 | |||
1696725879 | |||
002800f081 | |||
e15932bb60 | |||
ce741ba3e4 | |||
bf87484efa | |||
8ce9c50d40 | |||
32b6816e55 | |||
c128d69856 | |||
55b28b1eee | |||
e11222333f | |||
28873a2799 | |||
0080d8329d | |||
0d93f15694 | |||
becd7a56f1 | |||
75471386de | |||
d2b2eed67c | |||
4b6f069b6f | |||
791d79de32 | |||
94d2f59895 | |||
75c0ca9d43 | |||
2a4ec90854 | |||
85ebcda94d | |||
d64bf1646c | |||
a41c20435e | |||
eedac9dba0 | |||
14f9c72bfd | |||
ad5f2fe34c | |||
4f8584756d | |||
65fc1c3127 | |||
c393af6cd7 | |||
0c04ce3234 | |||
73b3de79ea | |||
d1744376ae | |||
805de738f6 | |||
1b151ed181 | |||
e06f504a76 | |||
462ae5220a | |||
66c54aa9c3 | |||
735ecfff61 | |||
a57d13cc96 | |||
79af7e96a0 | |||
621980bdc0 |
101
.github/workflows/publish.yml
vendored
Normal file
101
.github/workflows/publish.yml
vendored
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
# This workflow will upload a Python Package to Release asset
|
||||||
|
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions
|
||||||
|
|
||||||
|
name: Create Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- v*
|
||||||
|
|
||||||
|
# Needed to create release and upload assets
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
release:
|
||||||
|
# Retrieve tag and create release
|
||||||
|
name: Create Release
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
outputs:
|
||||||
|
upload_url: ${{ steps.create_release.outputs.upload_url }}
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Extract branch info
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Create Release
|
||||||
|
id: create_release
|
||||||
|
uses: "actions/github-script@v6"
|
||||||
|
env:
|
||||||
|
RELEASE_TAG: ${{ env.release_tag }}
|
||||||
|
with:
|
||||||
|
github-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||||
|
script: |
|
||||||
|
const script = require('.github/workflows/scripts/create_release.js')
|
||||||
|
await script(github, context, core)
|
||||||
|
|
||||||
|
wheel:
|
||||||
|
name: Build Wheel
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
needs: release
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: ['ubuntu-20.04']
|
||||||
|
python-version: ['3.8', '3.9', '3.10', '3.11']
|
||||||
|
cuda-version: ['11.8'] # Github runner can't build anything older than 11.8
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Set up Linux Env
|
||||||
|
if: ${{ runner.os == 'Linux' }}
|
||||||
|
run: |
|
||||||
|
bash -x .github/workflows/scripts/env.sh
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install CUDA ${{ matrix.cuda-version }}
|
||||||
|
run: |
|
||||||
|
bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}
|
||||||
|
|
||||||
|
- name: Install PyTorch-cu${{ matrix.cuda-version }}
|
||||||
|
run: |
|
||||||
|
bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
|
||||||
|
|
||||||
|
- name: Build wheel
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
|
||||||
|
wheel_name=$(ls dist/*whl | xargs -n 1 basename)
|
||||||
|
asset_name=${wheel_name//"linux"/"manylinux1"}
|
||||||
|
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
|
||||||
|
echo "asset_name=${asset_name}" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Upload Release Asset
|
||||||
|
uses: actions/upload-release-asset@v1
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
with:
|
||||||
|
upload_url: ${{ needs.release.outputs.upload_url }}
|
||||||
|
asset_path: ./dist/${{ env.wheel_name }}
|
||||||
|
asset_name: ${{ env.asset_name }}
|
||||||
|
asset_content_type: application/*
|
||||||
|
|
||||||
|
# (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
|
||||||
|
# - name: Publish package
|
||||||
|
# uses: pypa/gh-action-pypi-publish@release/v1.8
|
||||||
|
# with:
|
||||||
|
# repository-url: https://test.pypi.org/legacy/
|
||||||
|
# password: ${{ secrets.PYPI_API_TOKEN }}
|
||||||
|
# skip-existing: true
|
15
.github/workflows/scripts/build.sh
vendored
Normal file
15
.github/workflows/scripts/build.sh
vendored
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
python_executable=python$1
|
||||||
|
cuda_home=/usr/local/cuda-$2
|
||||||
|
|
||||||
|
# Update paths
|
||||||
|
PATH=${cuda_home}/bin:$PATH
|
||||||
|
LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
|
||||||
|
|
||||||
|
# Install requirements
|
||||||
|
$python_executable -m pip install wheel packaging
|
||||||
|
$python_executable -m pip install -r requirements.txt
|
||||||
|
|
||||||
|
# Build
|
||||||
|
$python_executable setup.py bdist_wheel --dist-dir=dist
|
20
.github/workflows/scripts/create_release.js
vendored
Normal file
20
.github/workflows/scripts/create_release.js
vendored
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
// Uses Github's API to create the release and wait for result.
|
||||||
|
// We use a JS script since github CLI doesn't provide a way to wait for the release's creation and returns immediately.
|
||||||
|
|
||||||
|
/**
 * Create a GitHub release for the tag named in the RELEASE_TAG environment
 * variable and publish its upload URL as the `upload_url` step output.
 *
 * The release is created as a normal (non-draft, non-prerelease) release with
 * auto-generated release notes; the tag is used both as the release name and
 * as `tag_name`. Any error thrown while creating the release is reported via
 * `core.setFailed` instead of being re-thrown.
 *
 * @param github  Client exposing `rest.repos.createRelease`.
 * @param context Workflow context; `context.repo.owner` / `context.repo.repo` are read.
 * @param core    Object providing `setOutput` and `setFailed`.
 */
module.exports = async (github, context, core) => {
|
||||||
|
try {
|
||||||
|
const response = await github.rest.repos.createRelease({
|
||||||
|
draft: false,
|
||||||
|
generate_release_notes: true,
|
||||||
|
name: process.env.RELEASE_TAG,
|
||||||
|
owner: context.repo.owner,
|
||||||
|
prerelease: false,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
tag_name: process.env.RELEASE_TAG,
|
||||||
|
});
|
||||||
|
|
||||||
|
core.setOutput('upload_url', response.data.upload_url);
|
||||||
|
} catch (error) {
|
||||||
|
core.setFailed(error.message);
|
||||||
|
}
|
||||||
|
}
|
18
.github/workflows/scripts/cuda-install.sh
vendored
Normal file
18
.github/workflows/scripts/cuda-install.sh
vendored
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Replace '.' with '-' ex: 11.8 -> 11-8
|
||||||
|
cuda_version=$(echo $1 | tr "." "-")
|
||||||
|
# Removes '-' and '.' ex: ubuntu-20.04 -> ubuntu2004
|
||||||
|
OS=$(echo $2 | tr -d ".\-")
|
||||||
|
|
||||||
|
# Installs CUDA
|
||||||
|
wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||||
|
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
||||||
|
rm cuda-keyring_1.1-1_all.deb
|
||||||
|
sudo apt -qq update
|
||||||
|
sudo apt -y install cuda-${cuda_version} cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version}
|
||||||
|
sudo apt clean
|
||||||
|
|
||||||
|
# Test nvcc
|
||||||
|
PATH=/usr/local/cuda-$1/bin:${PATH}
|
||||||
|
nvcc --version
|
56
.github/workflows/scripts/env.sh
vendored
Normal file
56
.github/workflows/scripts/env.sh
vendored
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# This file installs common linux environment tools
|
||||||
|
|
||||||
|
# Use a UTF-8 locale for the commands below. `export` needs NAME=VALUE; with a
# space instead of '=', bash treats "C.UTF-8" as a second variable name,
# rejects it as an invalid identifier, and leaves LANG unchanged.
export LANG=C.UTF-8
|
||||||
|
|
||||||
|
# python_version=$1
|
||||||
|
|
||||||
|
sudo apt-get update && \
|
||||||
|
sudo apt-get install -y --no-install-recommends \
|
||||||
|
software-properties-common \
|
||||||
|
|
||||||
|
sudo apt-get install -y --no-install-recommends \
|
||||||
|
build-essential \
|
||||||
|
apt-utils \
|
||||||
|
ca-certificates \
|
||||||
|
wget \
|
||||||
|
git \
|
||||||
|
vim \
|
||||||
|
libssl-dev \
|
||||||
|
curl \
|
||||||
|
unzip \
|
||||||
|
unrar \
|
||||||
|
cmake \
|
||||||
|
net-tools \
|
||||||
|
sudo \
|
||||||
|
autotools-dev \
|
||||||
|
rsync \
|
||||||
|
jq \
|
||||||
|
openssh-server \
|
||||||
|
tmux \
|
||||||
|
screen \
|
||||||
|
htop \
|
||||||
|
pdsh \
|
||||||
|
openssh-client \
|
||||||
|
lshw \
|
||||||
|
dmidecode \
|
||||||
|
util-linux \
|
||||||
|
automake \
|
||||||
|
autoconf \
|
||||||
|
libtool \
|
||||||
|
net-tools \
|
||||||
|
pciutils \
|
||||||
|
libpci-dev \
|
||||||
|
libaio-dev \
|
||||||
|
libcap2 \
|
||||||
|
libtinfo5 \
|
||||||
|
fakeroot \
|
||||||
|
devscripts \
|
||||||
|
debhelper \
|
||||||
|
nfs-common
|
||||||
|
|
||||||
|
# Remove github bloat files to free up disk space
|
||||||
|
sudo rm -rf "/usr/local/share/boost"
|
||||||
|
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
|
||||||
|
sudo rm -rf "/usr/share/dotnet"
|
14
.github/workflows/scripts/pytorch-install.sh
vendored
Normal file
14
.github/workflows/scripts/pytorch-install.sh
vendored
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
python_executable=python$1
|
||||||
|
cuda_version=$2
|
||||||
|
|
||||||
|
# Install torch
|
||||||
|
$python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya
|
||||||
|
$python_executable -m pip install torch -f https://download.pytorch.org/whl/cu${cuda_version//./}/torch_stable.html
|
||||||
|
|
||||||
|
# Print version information
|
||||||
|
$python_executable --version
|
||||||
|
$python_executable -c "import torch; print('PyTorch:', torch.__version__)"
|
||||||
|
$python_executable -c "import torch; print('CUDA:', torch.version.cuda)"
|
||||||
|
$python_executable -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
|
@ -17,6 +17,7 @@ Easy, fast, and cheap LLM serving for everyone
|
|||||||
---
|
---
|
||||||
|
|
||||||
*Latest News* 🔥
|
*Latest News* 🔥
|
||||||
|
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
|
||||||
- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
|
- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
|
||||||
- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
|
- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
|
||||||
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
|
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
|
||||||
@ -42,6 +43,7 @@ vLLM is flexible and easy to use with:
|
|||||||
|
|
||||||
vLLM seamlessly supports many Huggingface models, including the following architectures:
|
vLLM seamlessly supports many Huggingface models, including the following architectures:
|
||||||
|
|
||||||
|
- Aquila (`BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
|
||||||
- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
|
- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
|
||||||
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
|
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
|
||||||
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
|
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
|
||||||
@ -49,9 +51,11 @@ vLLM seamlessly supports many Huggingface models, including the following archit
|
|||||||
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
|
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
|
||||||
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
|
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
|
||||||
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
|
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
|
||||||
|
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
|
||||||
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
||||||
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
||||||
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
||||||
|
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
||||||
|
|
||||||
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
|
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
|
||||||
|
|
||||||
|
@ -4,9 +4,25 @@ void silu_and_mul(
|
|||||||
torch::Tensor& out,
|
torch::Tensor& out,
|
||||||
torch::Tensor& input);
|
torch::Tensor& input);
|
||||||
|
|
||||||
|
void gelu_new(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& input);
|
||||||
|
|
||||||
|
void gelu_fast(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& input);
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
m.def(
|
m.def(
|
||||||
"silu_and_mul",
|
"silu_and_mul",
|
||||||
&silu_and_mul,
|
&silu_and_mul,
|
||||||
"Activation function used in SwiGLU.");
|
"Activation function used in SwiGLU.");
|
||||||
|
m.def(
|
||||||
|
"gelu_new",
|
||||||
|
&gelu_new,
|
||||||
|
"GELU implementation used in GPT-2.");
|
||||||
|
m.def(
|
||||||
|
"gelu_fast",
|
||||||
|
&gelu_fast,
|
||||||
|
"Approximate GELU implementation.");
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
@ -34,9 +36,7 @@ void silu_and_mul(
|
|||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(d, 1024));
|
dim3 block(std::min(d, 1024));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
at::ScalarType::Half,
|
|
||||||
at::ScalarType::BFloat16,
|
|
||||||
input.scalar_type(),
|
input.scalar_type(),
|
||||||
"silu_and_mul_kernel",
|
"silu_and_mul_kernel",
|
||||||
[&] {
|
[&] {
|
||||||
@ -46,3 +46,69 @@ void silu_and_mul(
|
|||||||
d);
|
d);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
// Element-wise activation kernel template.
// Each CUDA block handles one row (blockIdx.x); its threads stride across the
// d elements of that row and write ACT_FN(input[row, i]) to out[row, i].
|
||||||
|
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
||||||
|
__global__ void activation_kernel(
|
||||||
|
scalar_t* __restrict__ out, // [num_tokens, d]
|
||||||
|
const scalar_t* __restrict__ input, // [num_tokens, d]
|
||||||
|
const int d) {
|
||||||
|
// 64-bit row index so that token_idx * d + idx cannot overflow a 32-bit int
// when num_tokens * d exceeds INT_MAX.
const int64_t token_idx = blockIdx.x;
|
||||||
|
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||||
|
const scalar_t x = __ldg(&input[token_idx * d + idx]);
|
||||||
|
out[token_idx * d + idx] = ACT_FN(x);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
// Launch element-wise activation kernel.
|
||||||
|
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
|
||||||
|
int num_tokens = input.size(0); \
|
||||||
|
int d = input.size(1); \
|
||||||
|
dim3 grid(num_tokens); \
|
||||||
|
dim3 block(std::min(d, 1024)); \
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||||
|
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||||
|
input.scalar_type(), \
|
||||||
|
"activation_kernel", \
|
||||||
|
[&] { \
|
||||||
|
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
|
||||||
|
out.data_ptr<scalar_t>(), \
|
||||||
|
input.data_ptr<scalar_t>(), \
|
||||||
|
d); \
|
||||||
|
});
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
// Tanh-based GELU approximation for a single element:
//   0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x^3)))
// where 0.79788456 ~= sqrt(2 / pi).
// Precision notes: x^3 is formed in type T and then widened to float; the
// tanh argument is rounded to T before tanhf is called, and the final
// product is evaluated in T.
template<typename T>
|
||||||
|
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
|
||||||
|
const float x3 = (float) (x * x * x);
|
||||||
|
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
|
||||||
|
return ((T) 0.5) * x * (((T) 1.0) + t);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tanh-based GELU approximation for a single element, written in factored form:
//   0.5 * x * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
// where 0.79788456 ~= sqrt(2 / pi).
// Precision notes: the two scalings (x * 0.79788456 and 0.044715 * x) are
// computed in float and rounded to T; the remaining multiply/add feeding
// tanhf, and the final product, are evaluated in T.
template<typename T>
|
||||||
|
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
|
||||||
|
const float f = (float) x;
|
||||||
|
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
|
||||||
|
return ((T) 0.5) * x * (((T) 1.0) + t);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
// Host entry point: applies vllm::gelu_new_kernel to every element of `input`
// and writes the results to `out`.
// The launch macro reads input.size(0) as the row count and input.size(1) as
// the row width, dispatches on input.scalar_type(), and launches on the
// current CUDA stream. `out` is not validated here; it is written with the
// same [row * d + idx] layout as `input` -- presumably callers pass a tensor
// of identical shape and dtype (TODO confirm against callers).
void gelu_new(
|
||||||
|
torch::Tensor& out, // [num_tokens, d]
|
||||||
|
torch::Tensor& input) // [num_tokens, d]
|
||||||
|
{
|
||||||
|
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Host entry point: applies vllm::gelu_fast_kernel to every element of `input`
// and writes the results to `out`.
// The launch macro reads input.size(0) as the row count and input.size(1) as
// the row width, dispatches on input.scalar_type(), and launches on the
// current CUDA stream. `out` is not validated here; it is written with the
// same [row * d + idx] layout as `input` -- presumably callers pass a tensor
// of identical shape and dtype (TODO confirm against callers).
void gelu_fast(
|
||||||
|
torch::Tensor& out, // [num_tokens, d]
|
||||||
|
torch::Tensor& input) // [num_tokens, d]
|
||||||
|
{
|
||||||
|
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
|
||||||
|
}
|
||||||
|
@ -86,6 +86,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|||||||
const int kv_block_stride,
|
const int kv_block_stride,
|
||||||
const int kv_head_stride) {
|
const int kv_head_stride) {
|
||||||
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||||
|
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
|
||||||
|
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
|
||||||
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
|
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
|
||||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||||
const int thread_idx = threadIdx.x;
|
const int thread_idx = threadIdx.x;
|
||||||
@ -120,12 +122,13 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|||||||
// th vectors of the query, and so on.
|
// th vectors of the query, and so on.
|
||||||
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
|
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
|
||||||
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||||
Q_vec q_vecs[NUM_VECS_PER_THREAD];
|
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
|
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
|
||||||
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
|
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
|
||||||
q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
|
q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
|
||||||
}
|
}
|
||||||
|
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
|
||||||
|
|
||||||
// Memory planning.
|
// Memory planning.
|
||||||
extern __shared__ char shared_mem[];
|
extern __shared__ char shared_mem[];
|
||||||
@ -173,9 +176,9 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|||||||
|
|
||||||
// Compute dot product.
|
// Compute dot product.
|
||||||
// This includes a reduction across the threads in the same thread group.
|
// This includes a reduction across the threads in the same thread group.
|
||||||
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
|
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
|
||||||
// Add the ALiBi bias if slopes are given.
|
// Add the ALiBi bias if slopes are given.
|
||||||
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;
|
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
|
||||||
|
|
||||||
if (thread_group_offset == 0) {
|
if (thread_group_offset == 0) {
|
||||||
// Store the partial reductions to shared memory.
|
// Store the partial reductions to shared memory.
|
||||||
@ -243,6 +246,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|||||||
accs[i] = 0.f;
|
accs[i] = 0.f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
scalar_t zero_value;
|
||||||
|
zero(zero_value);
|
||||||
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
|
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
|
||||||
const int physical_block_number = block_table[block_idx];
|
const int physical_block_number = block_table[block_idx];
|
||||||
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
|
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
|
||||||
@ -258,6 +263,16 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|||||||
if (row_idx < HEAD_SIZE) {
|
if (row_idx < HEAD_SIZE) {
|
||||||
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
||||||
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
||||||
|
if (block_idx == num_blocks - 1) {
|
||||||
|
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
|
||||||
|
// we should explicitly zero out the values since they may contain NaNs.
|
||||||
|
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
|
||||||
|
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
|
||||||
|
// Only indices [0, V_VEC_SIZE) are visited: V_vec is expected to hold exactly
// V_VEC_SIZE scalar_t elements (its definition is outside this hunk), so an
// inclusive bound would read and write one element past the end of v_vec.
#pragma unroll
|
||||||
|
for (int j = 0; j < V_VEC_SIZE; j++) {
|
||||||
|
v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
|
||||||
|
}
|
||||||
|
}
|
||||||
accs[i] += dot(logits_vec, v_vec);
|
accs[i] += dot(logits_vec, v_vec);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -420,4 +420,14 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Zero-out a variable.
|
||||||
|
inline __device__ void zero(__nv_bfloat16& dst) {
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||||
|
assert(false);
|
||||||
|
#else
|
||||||
|
// Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
|
||||||
|
dst = __ushort_as_bfloat16((unsigned short)0x0000U);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
@ -390,11 +390,6 @@ inline __device__ float sum(uint4 v) {
|
|||||||
return sum(c);
|
return sum(c);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Zero-out a vector.
|
|
||||||
inline __device__ void zero(uint16_t& dst) {
|
|
||||||
dst = uint16_t(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// From float32 to float16.
|
// From float32 to float16.
|
||||||
inline __device__ void from_float(uint16_t& dst, float src) {
|
inline __device__ void from_float(uint16_t& dst, float src) {
|
||||||
dst = float_to_half(src);
|
dst = float_to_half(src);
|
||||||
@ -441,4 +436,9 @@ inline __device__ Float8_ to_float(uint4 u) {
|
|||||||
return tmp;
|
return tmp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Zero-out a variable.
|
||||||
|
inline __device__ void zero(uint16_t& dst) {
|
||||||
|
dst = uint16_t(0);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
@ -265,4 +265,9 @@ inline __device__ Float8_ to_float(Float8_ u) {
|
|||||||
return u;
|
return u;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Zero-out a variable.
|
||||||
|
inline __device__ void zero(float& dst) {
|
||||||
|
dst = 0.f;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <map>
|
#include <map>
|
||||||
@ -125,9 +127,7 @@ void copy_blocks(
|
|||||||
dim3 grid(num_layers, num_pairs);
|
dim3 grid(num_layers, num_pairs);
|
||||||
dim3 block(std::min(1024, numel_per_block));
|
dim3 block(std::min(1024, numel_per_block));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
at::ScalarType::Half,
|
|
||||||
at::ScalarType::BFloat16,
|
|
||||||
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
||||||
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||||
@ -202,9 +202,7 @@ void reshape_and_cache(
|
|||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(num_heads * head_size, 512));
|
dim3 block(std::min(num_heads * head_size, 512));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
at::ScalarType::Half,
|
|
||||||
at::ScalarType::BFloat16,
|
|
||||||
key.scalar_type(),
|
key.scalar_type(),
|
||||||
"reshape_and_cache_kernel",
|
"reshape_and_cache_kernel",
|
||||||
[&] {
|
[&] {
|
||||||
@ -364,9 +362,7 @@ void gather_cached_kv(
|
|||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(num_heads * head_size, 512));
|
dim3 block(std::min(num_heads * head_size, 512));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
at::ScalarType::Half,
|
|
||||||
at::ScalarType::BFloat16,
|
|
||||||
key.scalar_type(),
|
key.scalar_type(),
|
||||||
"gather_cached_kv_kernel_optimized",
|
"gather_cached_kv_kernel_optimized",
|
||||||
[&] {
|
[&] {
|
||||||
|
14
csrc/dispatch_utils.h
Normal file
14
csrc/dispatch_utils.h
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
/*
|
||||||
|
* Adapted from
|
||||||
|
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
|
||||||
|
*/
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||||
|
AT_DISPATCH_SWITCH( \
|
||||||
|
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
@ -1,6 +1,7 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include "dispatch_utils.h"
|
||||||
#include "reduction_utils.cuh"
|
#include "reduction_utils.cuh"
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
@ -46,9 +47,7 @@ void rms_norm(
|
|||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(hidden_size, 1024));
|
dim3 block(std::min(hidden_size, 1024));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
at::ScalarType::Half,
|
|
||||||
at::ScalarType::BFloat16,
|
|
||||||
input.scalar_type(),
|
input.scalar_type(),
|
||||||
"rms_norm_kernel",
|
"rms_norm_kernel",
|
||||||
[&] {
|
[&] {
|
||||||
|
@ -1,15 +1,16 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
void rotary_embedding_neox(
|
void rotary_embedding(
|
||||||
torch::Tensor& positions,
|
torch::Tensor& positions,
|
||||||
torch::Tensor& query,
|
torch::Tensor& query,
|
||||||
torch::Tensor& key,
|
torch::Tensor& key,
|
||||||
int head_size,
|
int head_size,
|
||||||
torch::Tensor& cos_sin_cache);
|
torch::Tensor& cos_sin_cache,
|
||||||
|
bool is_neox);
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
m.def(
|
m.def(
|
||||||
"rotary_embedding_neox",
|
"rotary_embedding",
|
||||||
&rotary_embedding_neox,
|
&rotary_embedding,
|
||||||
"Apply GPT-NeoX style rotary embedding to query and key");
|
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,42 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
template<typename scalar_t>
|
template<typename scalar_t, bool IS_NEOX>
|
||||||
__global__ void rotary_embedding_neox_kernel(
|
inline __device__ void apply_rotary_embedding(
|
||||||
|
scalar_t* __restrict__ arr,
|
||||||
|
const scalar_t* __restrict__ cos_ptr,
|
||||||
|
const scalar_t* __restrict__ sin_ptr,
|
||||||
|
int rot_offset,
|
||||||
|
int embed_dim)
|
||||||
|
{
|
||||||
|
int x_index, y_index;
|
||||||
|
scalar_t cos, sin;
|
||||||
|
if (IS_NEOX) {
|
||||||
|
// GPT-NeoX style rotary embedding.
|
||||||
|
x_index = rot_offset;
|
||||||
|
y_index = embed_dim + rot_offset;
|
||||||
|
cos = __ldg(cos_ptr + x_index);
|
||||||
|
sin = __ldg(sin_ptr + x_index);
|
||||||
|
} else {
|
||||||
|
// GPT-J style rotary embedding.
|
||||||
|
x_index = 2 * rot_offset;
|
||||||
|
y_index = 2 * rot_offset + 1;
|
||||||
|
cos = __ldg(cos_ptr + x_index / 2);
|
||||||
|
sin = __ldg(sin_ptr + x_index / 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
const scalar_t x = arr[x_index];
|
||||||
|
const scalar_t y = arr[y_index];
|
||||||
|
arr[x_index] = x * cos - y * sin;
|
||||||
|
arr[y_index] = y * cos + x * sin;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename scalar_t, bool IS_NEOX>
|
||||||
|
__global__ void rotary_embedding_kernel(
|
||||||
const int64_t* __restrict__ positions, // [num_tokens]
|
const int64_t* __restrict__ positions, // [num_tokens]
|
||||||
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
||||||
scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
|
scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
|
||||||
@ -21,58 +53,37 @@ __global__ void rotary_embedding_neox_kernel(
|
|||||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||||
|
|
||||||
const int embed_dim = rot_dim / 2;
|
const int embed_dim = rot_dim / 2;
|
||||||
|
const scalar_t* cos_ptr = cache_ptr;
|
||||||
|
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
||||||
|
|
||||||
const int nq = num_heads * embed_dim;
|
const int nq = num_heads * embed_dim;
|
||||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||||
const int head_idx = i / embed_dim;
|
const int head_idx = i / embed_dim;
|
||||||
const int token_head = token_idx * query_stride + head_idx * head_size;
|
const int token_head = token_idx * query_stride + head_idx * head_size;
|
||||||
|
|
||||||
const int rot_offset = i % embed_dim;
|
const int rot_offset = i % embed_dim;
|
||||||
const int x_index = rot_offset;
|
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
|
||||||
const int y_index = embed_dim + rot_offset;
|
sin_ptr, rot_offset, embed_dim);
|
||||||
|
|
||||||
const int out_x = token_idx * query_stride + head_idx * head_size + x_index;
|
|
||||||
const int out_y = token_idx * query_stride + head_idx * head_size + y_index;
|
|
||||||
|
|
||||||
const scalar_t cos = __ldg(cache_ptr + x_index);
|
|
||||||
const scalar_t sin = __ldg(cache_ptr + y_index);
|
|
||||||
|
|
||||||
const scalar_t q_x = query[token_head + x_index];
|
|
||||||
const scalar_t q_y = query[token_head + y_index];
|
|
||||||
query[out_x] = q_x * cos - q_y * sin;
|
|
||||||
query[out_y] = q_y * cos + q_x * sin;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const int nk = num_kv_heads * embed_dim;
|
const int nk = num_kv_heads * embed_dim;
|
||||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||||
const int head_idx = i / embed_dim;
|
const int head_idx = i / embed_dim;
|
||||||
const int token_head = token_idx * key_stride + head_idx * head_size;
|
const int token_head = token_idx * key_stride + head_idx * head_size;
|
||||||
|
|
||||||
const int rot_offset = i % embed_dim;
|
const int rot_offset = i % embed_dim;
|
||||||
const int x_index = rot_offset;
|
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
|
||||||
const int y_index = embed_dim + rot_offset;
|
sin_ptr, rot_offset, embed_dim);
|
||||||
|
|
||||||
const int out_x = token_idx * key_stride + head_idx * head_size + x_index;
|
|
||||||
const int out_y = token_idx * key_stride + head_idx * head_size + y_index;
|
|
||||||
|
|
||||||
const scalar_t cos = __ldg(cache_ptr + x_index);
|
|
||||||
const scalar_t sin = __ldg(cache_ptr + y_index);
|
|
||||||
|
|
||||||
const scalar_t k_x = key[token_head + x_index];
|
|
||||||
const scalar_t k_y = key[token_head + y_index];
|
|
||||||
key[out_x] = k_x * cos - k_y * sin;
|
|
||||||
key[out_y] = k_y * cos + k_x * sin;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void rotary_embedding_neox(
|
void rotary_embedding(
|
||||||
torch::Tensor& positions, // [num_tokens]
|
torch::Tensor& positions, // [num_tokens]
|
||||||
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
||||||
torch::Tensor& key, // [num_tokens, num_kv_heads * head_size]
|
torch::Tensor& key, // [num_tokens, num_kv_heads * head_size]
|
||||||
int head_size,
|
int head_size,
|
||||||
torch::Tensor& cos_sin_cache) // [max_position, rot_dim]
|
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||||
{
|
bool is_neox) {
|
||||||
int num_tokens = query.size(0);
|
int num_tokens = query.size(0);
|
||||||
int rot_dim = cos_sin_cache.size(1);
|
int rot_dim = cos_sin_cache.size(1);
|
||||||
int num_heads = query.size(1) / head_size;
|
int num_heads = query.size(1) / head_size;
|
||||||
@ -83,13 +94,12 @@ void rotary_embedding_neox(
|
|||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
at::ScalarType::Half,
|
|
||||||
at::ScalarType::BFloat16,
|
|
||||||
query.scalar_type(),
|
query.scalar_type(),
|
||||||
"rotary_embedding_neox",
|
"rotary_embedding",
|
||||||
[&] {
|
[&] {
|
||||||
vllm::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
if (is_neox) {
|
||||||
|
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
|
||||||
positions.data_ptr<int64_t>(),
|
positions.data_ptr<int64_t>(),
|
||||||
query.data_ptr<scalar_t>(),
|
query.data_ptr<scalar_t>(),
|
||||||
key.data_ptr<scalar_t>(),
|
key.data_ptr<scalar_t>(),
|
||||||
@ -100,5 +110,18 @@ void rotary_embedding_neox(
|
|||||||
num_heads,
|
num_heads,
|
||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
head_size);
|
head_size);
|
||||||
|
} else {
|
||||||
|
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
|
||||||
|
positions.data_ptr<int64_t>(),
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
|
rot_dim,
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_size);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -59,7 +59,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
|
|||||||
+ kv_caches: List[KVCache],
|
+ kv_caches: List[KVCache],
|
||||||
+ input_metadata: InputMetadata,
|
+ input_metadata: InputMetadata,
|
||||||
+ cache_events: Optional[List[torch.cuda.Event]],
|
+ cache_events: Optional[List[torch.cuda.Event]],
|
||||||
+) -> Dict[int, SequenceOutputs]:
|
+) -> SamplerOutput:
|
||||||
|
|
||||||
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
|
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
|
||||||
4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.
|
4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.
|
||||||
|
@ -14,9 +14,12 @@ Alongside each architecture, we include some popular models that use it.
|
|||||||
* - Architecture
|
* - Architecture
|
||||||
- Models
|
- Models
|
||||||
- Example HuggingFace Models
|
- Example HuggingFace Models
|
||||||
|
* - :code:`AquilaForCausalLM`
|
||||||
|
- Aquila
|
||||||
|
- :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc.
|
||||||
* - :code:`BaiChuanForCausalLM`
|
* - :code:`BaiChuanForCausalLM`
|
||||||
- Baichuan
|
- Baichuan
|
||||||
- :code:`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.
|
- :code:`baichuan-inc/Baichuan-7B`, :code:`baichuan-inc/Baichuan-13B-Chat`, etc.
|
||||||
* - :code:`BloomForCausalLM`
|
* - :code:`BloomForCausalLM`
|
||||||
- BLOOM, BLOOMZ, BLOOMChat
|
- BLOOM, BLOOMZ, BLOOMChat
|
||||||
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
|
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
|
||||||
@ -35,15 +38,21 @@ Alongside each architecture, we include some popular models that use it.
|
|||||||
* - :code:`GPTNeoXForCausalLM`
|
* - :code:`GPTNeoXForCausalLM`
|
||||||
- GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
|
- GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
|
||||||
- :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
|
- :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
|
||||||
|
* - :code:`InternLMForCausalLM`
|
||||||
|
- InternLM
|
||||||
|
- :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc.
|
||||||
* - :code:`LlamaForCausalLM`
|
* - :code:`LlamaForCausalLM`
|
||||||
- LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
|
- LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
|
||||||
- :code:`meta-llama/Llama-2-13b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc.
|
- :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, etc.
|
||||||
* - :code:`MPTForCausalLM`
|
* - :code:`MPTForCausalLM`
|
||||||
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
|
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
|
||||||
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
|
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
|
||||||
* - :code:`OPTForCausalLM`
|
* - :code:`OPTForCausalLM`
|
||||||
- OPT, OPT-IML
|
- OPT, OPT-IML
|
||||||
- :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
|
- :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
|
||||||
|
* - :code:`QWenLMHeadModel`
|
||||||
|
- Qwen
|
||||||
|
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
|
||||||
|
|
||||||
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
|
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
|
||||||
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
|
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
|
||||||
|
@ -10,3 +10,4 @@ types-setuptools
|
|||||||
|
|
||||||
# testing
|
# testing
|
||||||
pytest
|
pytest
|
||||||
|
pytest-forked
|
||||||
|
@ -4,8 +4,8 @@ ray >= 2.5.1
|
|||||||
sentencepiece # Required for LLaMA tokenizer.
|
sentencepiece # Required for LLaMA tokenizer.
|
||||||
numpy
|
numpy
|
||||||
torch >= 2.0.0
|
torch >= 2.0.0
|
||||||
transformers >= 4.31.0 # Required for LLaMA-2.
|
transformers >= 4.33.1 # Required for Code Llama.
|
||||||
xformers >= 0.0.19
|
xformers >= 0.0.21
|
||||||
fastapi
|
fastapi
|
||||||
uvicorn
|
uvicorn
|
||||||
pydantic < 2 # Required for OpenAI server.
|
pydantic < 2 # Required for OpenAI server.
|
||||||
|
29
setup.py
29
setup.py
@ -22,7 +22,7 @@ NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
|||||||
|
|
||||||
if CUDA_HOME is None:
|
if CUDA_HOME is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Cannot find CUDA_HOME. CUDA must be available in order to build the package.")
|
f"Cannot find CUDA_HOME. CUDA must be available to build the package.")
|
||||||
|
|
||||||
|
|
||||||
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
||||||
@ -47,12 +47,6 @@ for i in range(device_count):
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"GPUs with compute capability less than 7.0 are not supported.")
|
"GPUs with compute capability less than 7.0 are not supported.")
|
||||||
compute_capabilities.add(major * 10 + minor)
|
compute_capabilities.add(major * 10 + minor)
|
||||||
# If no GPU is available, add all supported compute capabilities.
|
|
||||||
if not compute_capabilities:
|
|
||||||
compute_capabilities = {70, 75, 80, 86, 90}
|
|
||||||
# Add target compute capabilities to NVCC flags.
|
|
||||||
for capability in compute_capabilities:
|
|
||||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
|
|
||||||
|
|
||||||
# Validate the NVCC CUDA version.
|
# Validate the NVCC CUDA version.
|
||||||
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
||||||
@ -61,10 +55,31 @@ if nvcc_cuda_version < Version("11.0"):
|
|||||||
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
|
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
|
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
|
||||||
|
if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
||||||
|
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
|
||||||
|
# However, GPUs with compute capability 8.9 can also run the code generated by
|
||||||
|
# the previous versions of CUDA 11 and targeting compute capability 8.0.
|
||||||
|
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
|
||||||
|
# instead of 8.9.
|
||||||
|
compute_capabilities.remove(89)
|
||||||
|
compute_capabilities.add(80)
|
||||||
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
|
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
|
||||||
|
|
||||||
|
# If no GPU is available, add all supported compute capabilities.
|
||||||
|
if not compute_capabilities:
|
||||||
|
compute_capabilities = {70, 75, 80}
|
||||||
|
if nvcc_cuda_version >= Version("11.1"):
|
||||||
|
compute_capabilities.add(86)
|
||||||
|
if nvcc_cuda_version >= Version("11.8"):
|
||||||
|
compute_capabilities.add(89)
|
||||||
|
compute_capabilities.add(90)
|
||||||
|
|
||||||
|
# Add target compute capabilities to NVCC flags.
|
||||||
|
for capability in compute_capabilities:
|
||||||
|
NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
|
||||||
|
|
||||||
# Use NVCC threads to parallelize the build.
|
# Use NVCC threads to parallelize the build.
|
||||||
if nvcc_cuda_version >= Version("11.2"):
|
if nvcc_cuda_version >= Version("11.2"):
|
||||||
num_threads = min(os.cpu_count(), 8)
|
num_threads = min(os.cpu_count(), 8)
|
||||||
|
51
tests/async_engine/api_server_async_engine.py
Normal file
51
tests/async_engine/api_server_async_engine.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
"""vllm.entrypoints.api_server with some extra logging for testing."""
|
||||||
|
import argparse
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
from fastapi.responses import JSONResponse, Response
|
||||||
|
|
||||||
|
import vllm.entrypoints.api_server
|
||||||
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
|
|
||||||
|
app = vllm.entrypoints.api_server.app
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncLLMEngineWithStats(AsyncLLMEngine):
    """AsyncLLMEngine variant that counts completed abort() calls."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Incremented once per abort() call that returns from the parent.
        self._num_aborts = 0

    async def abort(self, request_id: str) -> None:
        """Delegate the abort to the parent engine, then record it."""
        await super().abort(request_id)
        self._num_aborts += 1

    def testing_stats(self) -> Dict[str, Any]:
        """Return the counters collected so far, keyed by statistic name."""
        collected = {"num_aborted_requests": self._num_aborts}
        return collected
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/stats")
def stats() -> Response:
    """Return the engine's testing statistics as a JSON response."""
    # `engine` is the module-level engine created in the __main__ section.
    payload = engine.testing_stats()
    return JSONResponse(payload)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command line: server address plus every engine option.
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="localhost")
    cli.add_argument("--port", type=int, default=8000)
    cli = AsyncEngineArgs.add_cli_args(cli)
    args = cli.parse_args()

    # `engine` has to stay a module-level name: the /stats handler reads it.
    engine = AsyncLLMEngineWithStats.from_engine_args(
        AsyncEngineArgs.from_cli_args(args),
        start_engine_loop=False,
    )
    # Point the stock api_server module at the instrumented engine as well.
    vllm.entrypoints.api_server.engine = engine

    keep_alive = vllm.entrypoints.api_server.TIMEOUT_KEEP_ALIVE
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="debug",
                timeout_keep_alive=keep_alive)
|
86
tests/async_engine/test_api_server.py
Normal file
86
tests/async_engine/test_api_server.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from multiprocessing import Pool
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def _query_server(prompt: str) -> dict:
    """POST one generation request to the local server and return its JSON.

    Raises whatever `raise_for_status()` raises on a non-success response.
    """
    payload = {
        "prompt": prompt,
        "max_tokens": 100,
        "temperature": 0,
        "ignore_eos": True,
    }
    reply = requests.post("http://localhost:8000/generate", json=payload)
    reply.raise_for_status()
    return reply.json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def api_server():
    """Run api_server_async_engine.py in a subprocess for one test.

    The server script sitting next to this file is started with the
    current interpreter and the facebook/opt-125m model. The fixture
    yields nothing; on teardown the process is terminated and reaped.
    """
    script_path = Path(__file__).parent.joinpath(
        "api_server_async_engine.py").absolute()
    uvicorn_process = subprocess.Popen([
        sys.executable, "-u",
        str(script_path), "--model", "facebook/opt-125m"
    ])
    try:
        yield
    finally:
        uvicorn_process.terminate()
        # Wait for the child to exit so it is reaped and its port is
        # released before the next test starts; escalate to kill() if it
        # does not react to terminate() in time.
        try:
            uvicorn_process.wait(timeout=30)
        except subprocess.TimeoutExpired:
            uvicorn_process.kill()
            uvicorn_process.wait()
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_server(api_server):
    """
    Run the API server and test it.

    We run both the server and requests in separate processes.

    We test that the server can handle incoming requests, including
    multiple requests at the same time, and that it can handle requests
    being cancelled without crashing.
    """
    with Pool(32) as pool:
        # Wait until the server is ready
        prompts = ["Hello world"] * 1
        result = None
        while not result:
            try:
                result = pool.map(_query_server, prompts)[0]
            except Exception:
                # Server not reachable yet. `Exception` (not a bare
                # except) keeps Ctrl-C / SystemExit able to stop the loop.
                time.sleep(1)

        # Actual tests start here
        # Try with 1 prompt
        for result in pool.map(_query_server, prompts):
            assert result

        num_aborted_requests = requests.get(
            "http://localhost:8000/stats").json()["num_aborted_requests"]
        assert num_aborted_requests == 0

        # Try with 100 prompts
        prompts = ["Hello world"] * 100
        for result in pool.map(_query_server, prompts):
            assert result

        # Cancel requests: start them, then kill the workers mid-flight.
        pool.map_async(_query_server, prompts)
        time.sleep(0.01)
        pool.terminate()
        pool.join()

        # check cancellation stats
        num_aborted_requests = requests.get(
            "http://localhost:8000/stats").json()["num_aborted_requests"]
        assert num_aborted_requests > 0

    # check that server still runs after cancellations
    with Pool(32) as pool:
        # Try with 100 prompts
        prompts = ["Hello world"] * 100
        for result in pool.map(_query_server, prompts):
            assert result
|
54
tests/async_engine/test_request_tracker.py
Normal file
54
tests/async_engine/test_request_tracker.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.engine.async_llm_engine import RequestTracker
|
||||||
|
from vllm.outputs import RequestOutput
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_tracker():
    """Check what RequestTracker reports after adds, aborts and a finish."""
    tracker = RequestTracker()

    # One added request shows up as new; nothing is finished yet.
    first_stream = tracker.add_request("1")
    new_reqs, done_ids = tracker.get_new_and_finished_requests()
    assert len(new_reqs) == 1
    assert new_reqs[0]["request_id"] == "1"
    assert not done_ids
    assert not first_stream.finished

    # Two more requests are reported in the order they were added.
    second_stream = tracker.add_request("2")
    third_stream = tracker.add_request("3")
    new_reqs, done_ids = tracker.get_new_and_finished_requests()
    assert len(new_reqs) == 2
    assert new_reqs[0]["request_id"] == "2"
    assert new_reqs[1]["request_id"] == "3"
    assert not done_ids
    assert not second_stream.finished
    assert not third_stream.finished

    # request_ids must be unique
    with pytest.raises(KeyError):
        tracker.add_request("1")

    # Aborting a request that was already handed out reports it finished.
    tracker.abort_request("1")
    new_reqs, done_ids = tracker.get_new_and_finished_requests()
    assert len(done_ids) == 1
    assert "1" in done_ids
    assert not new_reqs
    assert first_stream.finished

    # A request aborted before it was ever handed out is not reported new.
    fourth_stream = tracker.add_request("4")
    tracker.abort_request("4")
    new_reqs, done_ids = tracker.get_new_and_finished_requests()
    assert len(done_ids) == 1
    assert "4" in done_ids
    assert not new_reqs
    assert fourth_stream.finished

    # A finished output for "2" is reported alongside the new request "5".
    fifth_stream = tracker.add_request("5")
    final_output = RequestOutput("2", "output", [], [], finished=True)
    tracker.process_request_output(final_output)
    new_reqs, done_ids = tracker.get_new_and_finished_requests()
    assert len(done_ids) == 1
    assert "2" in done_ids
    assert len(new_reqs) == 1
    assert new_reqs[0]["request_id"] == "5"
    assert second_stream.finished
    assert not fifth_stream.finished
|
178
tests/conftest.py
Normal file
178
tests/conftest.py
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
|
_TEST_PROMPTS = [
|
||||||
|
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
|
||||||
|
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
|
||||||
|
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
|
||||||
|
"Describe the basic components of a neural network and how it can be trained.",
|
||||||
|
"Write a short story about a robot that dreams for the first time.",
|
||||||
|
"Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
|
||||||
|
"Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
|
||||||
|
"Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def example_prompts() -> List[str]:
    """Return the standard test prompts.

    A fresh list is returned on every use so that a test which mutates
    its prompts cannot change the module-level list seen by later tests.
    """
    return list(_TEST_PROMPTS)
|
||||||
|
|
||||||
|
|
||||||
|
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||||
|
"half": torch.half,
|
||||||
|
"bfloat16": torch.bfloat16,
|
||||||
|
"float": torch.float,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class HfRunner:
    """Test helper that generates text with a HuggingFace causal LM."""

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        dtype: str = "half",
    ) -> None:
        """Load the model onto CUDA and build its tokenizer.

        Args:
            model_name: Name passed to `AutoModelForCausalLM.from_pretrained`.
            tokenizer_name: Tokenizer to load; defaults to `model_name`.
            dtype: Key of `_STR_DTYPE_TO_TORCH_DTYPE`.

        Raises:
            ValueError: If `dtype` is not a known dtype name.
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
            raise ValueError(f"Unsupported dtype: {dtype!r}")
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        ).cuda()
        if tokenizer_name is None:
            tokenizer_name = model_name
        self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)

    def generate(
        self,
        prompts: List[str],
        **kwargs,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Generate for each prompt separately.

        Extra keyword arguments are forwarded to `self.model.generate`.

        Returns:
            One `(token_id_lists, decoded_strings)` pair per prompt, each
            holding one entry per sequence returned by the model.
        """
        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output_ids = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                **kwargs,
            )
            output_str = self.tokenizer.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        """Generate without sampling; return the first sequence per prompt."""
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        """Beam-search generation returning `beam_width` sequences per prompt.

        Pad token ids are removed from every returned id list.
        """
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        pad_id = self.tokenizer.pad_token_id
        return [([[tok for tok in seq if tok != pad_id]
                  for seq in output_ids], output_str)
                for output_ids, output_str in outputs]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def hf_runner():
    """Provide the HfRunner class itself; tests instantiate it as needed."""
    runner_cls = HfRunner
    return runner_cls
|
||||||
|
|
||||||
|
|
||||||
|
class VllmRunner:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
tokenizer_name: Optional[str] = None,
|
||||||
|
dtype: str = "half",
|
||||||
|
) -> None:
|
||||||
|
self.model = LLM(
|
||||||
|
model=model_name,
|
||||||
|
tokenizer=tokenizer_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
dtype=dtype,
|
||||||
|
swap_space=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
) -> List[Tuple[List[int], str]]:
|
||||||
|
req_outputs = self.model.generate(prompts,
|
||||||
|
sampling_params=sampling_params)
|
||||||
|
outputs = []
|
||||||
|
for req_output in req_outputs:
|
||||||
|
prompt_str = req_output.prompt
|
||||||
|
prompt_ids = req_output.prompt_token_ids
|
||||||
|
req_sample_output_ids = []
|
||||||
|
req_sample_output_strs = []
|
||||||
|
for sample in req_output.outputs:
|
||||||
|
output_str = sample.text
|
||||||
|
output_ids = sample.token_ids
|
||||||
|
req_sample_output_ids.append(prompt_ids + output_ids)
|
||||||
|
req_sample_output_strs.append(prompt_str + output_str)
|
||||||
|
outputs.append((req_sample_output_ids, req_sample_output_strs))
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def generate_greedy(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
max_tokens: int,
|
||||||
|
) -> List[Tuple[List[int], str]]:
|
||||||
|
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||||
|
outputs = self.generate(prompts, greedy_params)
|
||||||
|
return [(output_ids[0], output_str[0])
|
||||||
|
for output_ids, output_str in outputs]
|
||||||
|
|
||||||
|
def generate_beam_search(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
beam_width: int,
|
||||||
|
max_tokens: int,
|
||||||
|
) -> List[Tuple[List[int], str]]:
|
||||||
|
beam_search_params = SamplingParams(n=beam_width,
|
||||||
|
use_beam_search=True,
|
||||||
|
temperature=0.0,
|
||||||
|
max_tokens=max_tokens)
|
||||||
|
outputs = self.generate(prompts, beam_search_params)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def vllm_runner():
    """Provide the VllmRunner class itself; tests instantiate it as needed."""
    runner_cls = VllmRunner
    return runner_cls
|
43
tests/kernels/conftest.py
Normal file
43
tests/kernels/conftest.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def create_kv_caches(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str = 'cuda',
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Create randomly initialised key and value caches, one per layer.

    Args:
        num_blocks: Size of the first dimension of every cache tensor.
        block_size: Size of the block dimension of every cache tensor.
        num_layers: Number of key caches and of value caches to create.
        num_heads: Size of the head dimension of every cache tensor.
        head_size: Per-head size; also sets the value range of the data.
        dtype: Element type of the caches.
        seed: Seed applied to both the CPU and the CUDA generators.
        device: Device the caches are allocated on (default 'cuda').

    Returns:
        `(key_caches, value_caches)`. Key caches have shape
        `(num_blocks, num_heads, head_size // x, block_size, x)` and value
        caches `(num_blocks, num_heads, head_size, block_size)`; all values
        are drawn uniformly from `[-head_size**-0.5, head_size**-0.5]`.
    """
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = head_size**-0.5
    # x = number of `dtype` elements that fit into 16 bytes.
    x = 16 // torch.tensor([], dtype=dtype).element_size()

    def _random_caches(shape: Tuple[int, ...]) -> List[torch.Tensor]:
        # Allocate and fill `num_layers` tensors of the given shape.
        caches = []
        for _ in range(num_layers):
            cache = torch.empty(size=shape, dtype=dtype, device=device)
            cache.uniform_(-scale, scale)
            caches.append(cache)
        return caches

    # All key caches are filled before any value cache, so the sequence of
    # random draws for a given seed stays the same as before.
    key_caches = _random_caches(
        (num_blocks, num_heads, head_size // x, block_size, x))
    value_caches = _random_caches(
        (num_blocks, num_heads, head_size, block_size))
    return key_caches, value_caches
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def kv_cache_factory():
    """Provide the create_kv_caches function so tests can build caches."""
    factory = create_kv_caches
    return factory
|
@ -1,20 +1,34 @@
|
|||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from transformers.activations import get_activation
|
||||||
|
|
||||||
from vllm import activation_ops
|
from vllm import activation_ops
|
||||||
|
|
||||||
|
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||||
|
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
|
||||||
|
D = [512, 4096, 5120, 13824] # Arbitrary values for testing
|
||||||
|
SEEDS = [0]
|
||||||
|
|
||||||
|
|
||||||
def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
|
def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
|
||||||
x1, x2 = x.chunk(chunks=2, dim=1)
|
x1, x2 = x.chunk(chunks=2, dim=1)
|
||||||
return F.silu(x1) * x2
|
return F.silu(x1) * x2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||||
|
@pytest.mark.parametrize("d", D)
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def run_silu_and_mul(
|
def test_silu_and_mul(
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
d: int,
|
d: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
|
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
|
||||||
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||||
activation_ops.silu_and_mul(out, x)
|
activation_ops.silu_and_mul(out, x)
|
||||||
@ -22,9 +36,40 @@ def run_silu_and_mul(
|
|||||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
def test_silu_and_mul() -> None:
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
@pytest.mark.parametrize("d", D)
|
||||||
for num_tokens in [7, 83, 2048]:
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
for d in [512, 4096, 5120, 13824]:
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
|
@torch.inference_mode()
|
||||||
run_silu_and_mul(num_tokens, d, dtype)
|
def test_gelu_new(
|
||||||
|
num_tokens: int,
|
||||||
|
d: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
|
) -> None:
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
|
||||||
|
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||||
|
activation_ops.gelu_new(out, x)
|
||||||
|
ref_out = get_activation("gelu_new")(x)
|
||||||
|
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||||
|
@pytest.mark.parametrize("d", D)
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
|
def test_gelu_fast(
|
||||||
|
num_tokens: int,
|
||||||
|
d: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
|
) -> None:
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
|
||||||
|
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||||
|
activation_ops.gelu_fast(out, x)
|
||||||
|
ref_out = get_activation("gelu_fast")(x)
|
||||||
|
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||||
|
@ -1,14 +1,24 @@
|
|||||||
import random
|
import random
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from xformers import ops as xops
|
from xformers import ops as xops
|
||||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||||
|
|
||||||
from vllm import attention_ops
|
from vllm import attention_ops
|
||||||
|
|
||||||
MAX_SEQ_LEN = 4096
|
MAX_SEQ_LEN = 8192
|
||||||
TEST_SEED = 0
|
NUM_BLOCKS = 128 # Arbitrary values for testing
|
||||||
|
|
||||||
|
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||||
|
NUM_GEN_SEQS = [7] # Arbitrary values for testing
|
||||||
|
NUM_PREFILL_SEQS = [1, 3, 7] # Arbitrary values for testing
|
||||||
|
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
|
||||||
|
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||||
|
BLOCK_SIZES = [8, 16, 32]
|
||||||
|
USE_ALIBI = [False, True]
|
||||||
|
SEEDS = [0]
|
||||||
|
|
||||||
|
|
||||||
def ref_masked_attention(
|
def ref_masked_attention(
|
||||||
@ -18,29 +28,34 @@ def ref_masked_attention(
|
|||||||
scale: float,
|
scale: float,
|
||||||
attn_mask: Optional[torch.Tensor] = None,
|
attn_mask: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
query = query * scale
|
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
|
||||||
attn = torch.einsum('qhd,khd->hqk', query, key)
|
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
attn = attn + attn_mask
|
attn_weights = attn_weights + attn_mask.float()
|
||||||
attn = torch.softmax(attn, dim=-1)
|
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
||||||
out = torch.einsum('hqk,khd->qhd', attn, value)
|
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def ref_single_query_cached_kv_attention(
|
def ref_single_query_cached_kv_attention(
|
||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
|
num_queries_per_kv: int,
|
||||||
key_cache: torch.Tensor,
|
key_cache: torch.Tensor,
|
||||||
value_cache: torch.Tensor,
|
value_cache: torch.Tensor,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
context_lens: torch.Tensor,
|
context_lens: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
|
alibi_slopes: Optional[torch.Tensor],
|
||||||
) -> None:
|
) -> None:
|
||||||
num_heads = value_cache.shape[1]
|
num_query_heads = query.shape[1]
|
||||||
|
num_kv_heads = value_cache.shape[1]
|
||||||
head_size = value_cache.shape[2]
|
head_size = value_cache.shape[2]
|
||||||
block_size = value_cache.shape[3]
|
block_size = value_cache.shape[3]
|
||||||
|
num_seqs = query.shape[0]
|
||||||
|
|
||||||
num_input_tokens = query.shape[0]
|
block_tables = block_tables.cpu().tolist()
|
||||||
for i in range(num_input_tokens):
|
context_lens = context_lens.cpu().tolist()
|
||||||
|
for i in range(num_seqs):
|
||||||
q = query[i].unsqueeze(0)
|
q = query[i].unsqueeze(0)
|
||||||
block_table = block_tables[i]
|
block_table = block_tables[i]
|
||||||
context_len = int(context_lens[i])
|
context_len = int(context_lens[i])
|
||||||
@ -52,30 +67,138 @@ def ref_single_query_cached_kv_attention(
|
|||||||
block_offset = j % block_size
|
block_offset = j % block_size
|
||||||
|
|
||||||
k = key_cache[block_number, :, :, block_offset, :]
|
k = key_cache[block_number, :, :, block_offset, :]
|
||||||
k = k.reshape(num_heads, head_size)
|
k = k.reshape(num_kv_heads, head_size)
|
||||||
keys.append(k)
|
keys.append(k)
|
||||||
|
|
||||||
v = value_cache[block_number, :, :, block_offset]
|
v = value_cache[block_number, :, :, block_offset]
|
||||||
values.append(v)
|
values.append(v)
|
||||||
keys = torch.stack(keys, dim=0)
|
keys = torch.stack(keys, dim=0)
|
||||||
values = torch.stack(values, dim=0)
|
values = torch.stack(values, dim=0)
|
||||||
|
if num_queries_per_kv > 1:
|
||||||
|
# Handle MQA and GQA
|
||||||
|
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
|
||||||
|
values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
|
||||||
|
|
||||||
scale = 1.0 / (head_size**0.5)
|
alibi_bias = None
|
||||||
out = ref_masked_attention(q, keys, values, scale)
|
if alibi_slopes is not None:
|
||||||
out = out.view(num_heads, head_size)
|
# Create the ALiBi bias used in the paged attention kernel.
|
||||||
|
position_ids = torch.arange(context_len, device="cuda").int()
|
||||||
|
alibi_bias = (position_ids - context_len + 1).float()
|
||||||
|
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
|
||||||
|
1, 1, -1)
|
||||||
|
|
||||||
|
out = ref_masked_attention(q, keys, values, scale, alibi_bias)
|
||||||
|
out = out.view(num_query_heads, head_size)
|
||||||
output[i].copy_(out, non_blocking=True)
|
output[i].copy_(out, non_blocking=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
|
||||||
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
|
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
|
||||||
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_single_query_cached_kv_attention(
|
||||||
|
kv_cache_factory,
|
||||||
|
num_seqs: int,
|
||||||
|
num_heads: Tuple[int, int],
|
||||||
|
head_size: int,
|
||||||
|
use_alibi: bool,
|
||||||
|
block_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
|
) -> None:
|
||||||
|
random.seed(seed)
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
scale = float(1.0 / (head_size**0.5))
|
||||||
|
num_query_heads, num_kv_heads = num_heads
|
||||||
|
query = torch.empty(num_seqs,
|
||||||
|
num_query_heads,
|
||||||
|
head_size,
|
||||||
|
dtype=dtype,
|
||||||
|
device="cuda")
|
||||||
|
query.uniform_(-scale, scale)
|
||||||
|
|
||||||
|
assert num_query_heads % num_kv_heads == 0
|
||||||
|
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||||
|
head_mapping = torch.repeat_interleave(
|
||||||
|
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
|
||||||
|
num_queries_per_kv)
|
||||||
|
alibi_slopes = None
|
||||||
|
if use_alibi:
|
||||||
|
alibi_slopes = torch.randn(num_query_heads,
|
||||||
|
dtype=torch.float,
|
||||||
|
device="cuda")
|
||||||
|
|
||||||
|
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||||
|
max_context_len = max(context_lens)
|
||||||
|
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
|
||||||
|
|
||||||
|
# Create the block tables.
|
||||||
|
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||||
|
block_tables = []
|
||||||
|
for _ in range(num_seqs):
|
||||||
|
block_table = [
|
||||||
|
random.randint(0, NUM_BLOCKS - 1)
|
||||||
|
for _ in range(max_num_blocks_per_seq)
|
||||||
|
]
|
||||||
|
block_tables.append(block_table)
|
||||||
|
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
|
||||||
|
|
||||||
|
# Create the KV caches.
|
||||||
|
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
|
||||||
|
num_kv_heads, head_size, dtype,
|
||||||
|
seed)
|
||||||
|
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||||
|
|
||||||
|
# Call the paged attention kernel.
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
attention_ops.single_query_cached_kv_attention(
|
||||||
|
output,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
head_mapping,
|
||||||
|
scale,
|
||||||
|
block_tables,
|
||||||
|
context_lens,
|
||||||
|
block_size,
|
||||||
|
max_context_len,
|
||||||
|
alibi_slopes,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the reference implementation.
|
||||||
|
ref_output = torch.empty_like(query)
|
||||||
|
ref_single_query_cached_kv_attention(
|
||||||
|
ref_output,
|
||||||
|
query,
|
||||||
|
num_queries_per_kv,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
block_tables,
|
||||||
|
context_lens,
|
||||||
|
scale,
|
||||||
|
alibi_slopes,
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE(woosuk): Due to the kernel-level differences in the two
|
||||||
|
# implementations, there is a small numerical difference in the two
|
||||||
|
# outputs. Thus, we use a relaxed tolerance for the test.
|
||||||
|
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
def ref_multi_query_kv_attention(
|
def ref_multi_query_kv_attention(
|
||||||
cu_seq_lens: List[int],
|
cu_seq_lens: List[int],
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
head_size = query.shape[-1]
|
|
||||||
scale = 1.0 / (head_size**0.5)
|
|
||||||
|
|
||||||
num_seqs = len(cu_seq_lens) - 1
|
num_seqs = len(cu_seq_lens) - 1
|
||||||
ref_outputs = []
|
ref_outputs = []
|
||||||
for i in range(num_seqs):
|
for i in range(num_seqs):
|
||||||
@ -87,7 +210,7 @@ def ref_multi_query_kv_attention(
|
|||||||
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
|
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
|
||||||
diagonal=1)
|
diagonal=1)
|
||||||
attn_mask = attn_mask * torch.finfo(dtype).min
|
attn_mask = attn_mask * torch.finfo(dtype).min
|
||||||
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
|
attn_mask = attn_mask.to(dtype=dtype, device="cuda")
|
||||||
|
|
||||||
ref_output = ref_masked_attention(
|
ref_output = ref_masked_attention(
|
||||||
query[start_idx:end_idx],
|
query[start_idx:end_idx],
|
||||||
@ -101,172 +224,43 @@ def ref_multi_query_kv_attention(
|
|||||||
return ref_output
|
return ref_output
|
||||||
|
|
||||||
|
|
||||||
def ref_multi_query_cached_kv_attention(
|
# TODO(woosuk): Add tests for USE_ALIBI=True.
|
||||||
cu_query_lens: List[int],
|
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
|
||||||
query: torch.Tensor,
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
key_cache: torch.Tensor,
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
value_cache: torch.Tensor,
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
block_tables: torch.Tensor,
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
context_lens: torch.Tensor,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
num_heads = value_cache.shape[1]
|
|
||||||
head_size = value_cache.shape[2]
|
|
||||||
block_size = value_cache.shape[3]
|
|
||||||
scale = 1.0 / (head_size**0.5)
|
|
||||||
|
|
||||||
num_queries = len(cu_query_lens) - 1
|
|
||||||
ref_outputs = []
|
|
||||||
for i in range(num_queries):
|
|
||||||
start_idx = cu_query_lens[i]
|
|
||||||
end_idx = cu_query_lens[i + 1]
|
|
||||||
query_len = end_idx - start_idx
|
|
||||||
context_len = int(context_lens[i])
|
|
||||||
block_table = block_tables[i]
|
|
||||||
|
|
||||||
# Create attention mask
|
|
||||||
attn_mask = torch.triu(torch.ones(query_len, context_len),
|
|
||||||
diagonal=context_len - query_len + 1) * -1e5
|
|
||||||
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
|
|
||||||
|
|
||||||
keys = []
|
|
||||||
values = []
|
|
||||||
for j in range(context_len):
|
|
||||||
block_number = int(block_table[j // block_size])
|
|
||||||
block_offset = j % block_size
|
|
||||||
|
|
||||||
k = key_cache[block_number, :, :, block_offset, :]
|
|
||||||
k = k.reshape(num_heads, head_size)
|
|
||||||
keys.append(k)
|
|
||||||
|
|
||||||
v = value_cache[block_number, :, :, block_offset]
|
|
||||||
values.append(v)
|
|
||||||
keys = torch.stack(keys, dim=0)
|
|
||||||
values = torch.stack(values, dim=0)
|
|
||||||
|
|
||||||
ref_output = ref_masked_attention(
|
|
||||||
query[start_idx:end_idx],
|
|
||||||
keys,
|
|
||||||
values,
|
|
||||||
scale,
|
|
||||||
attn_mask=attn_mask,
|
|
||||||
)
|
|
||||||
ref_outputs.append(ref_output)
|
|
||||||
ref_output = torch.cat(ref_outputs, dim=0)
|
|
||||||
return ref_output
|
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def run_single_query_cached_kv_attention(
|
def test_multi_query_kv_attention(
|
||||||
num_tokens: int,
|
|
||||||
num_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
block_size: int,
|
|
||||||
num_blocks: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
num_kv_heads: int = None,
|
|
||||||
) -> None:
|
|
||||||
qkv = torch.empty(num_tokens,
|
|
||||||
3,
|
|
||||||
num_heads,
|
|
||||||
head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
qkv.uniform_(-1e-3, 1e-3)
|
|
||||||
query, _, _ = qkv.unbind(dim=1)
|
|
||||||
|
|
||||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
|
||||||
key_block_shape = (num_heads, head_size // x, block_size, x)
|
|
||||||
key_cache = torch.empty(size=(num_blocks, *key_block_shape),
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
key_cache.uniform_(-1e-3, 1e-3)
|
|
||||||
value_block_shape = (num_heads, head_size, block_size)
|
|
||||||
value_cache = torch.empty(size=(num_blocks, *value_block_shape),
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
value_cache.uniform_(-1e-3, 1e-3)
|
|
||||||
|
|
||||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
|
|
||||||
max_context_len = max(context_lens)
|
|
||||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
|
|
||||||
|
|
||||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
|
||||||
block_tables = []
|
|
||||||
for _ in range(num_tokens):
|
|
||||||
block_table = [
|
|
||||||
random.randint(0, num_blocks - 1)
|
|
||||||
for _ in range(max_num_blocks_per_seq)
|
|
||||||
]
|
|
||||||
block_tables.append(block_table)
|
|
||||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
|
|
||||||
head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda")
|
|
||||||
|
|
||||||
scale = float(1.0 / (head_size**0.5))
|
|
||||||
|
|
||||||
num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
|
||||||
assert num_heads % num_kv_heads == 0
|
|
||||||
num_queries_per_kv = num_heads // num_kv_heads
|
|
||||||
head_mapping = torch.repeat_interleave(
|
|
||||||
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
|
|
||||||
num_queries_per_kv)
|
|
||||||
|
|
||||||
output = torch.empty(num_tokens,
|
|
||||||
num_heads,
|
|
||||||
head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
attention_ops.single_query_cached_kv_attention(
|
|
||||||
output,
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
head_mapping,
|
|
||||||
scale,
|
|
||||||
block_tables,
|
|
||||||
context_lens,
|
|
||||||
block_size,
|
|
||||||
max_context_len,
|
|
||||||
None, # ALiBi slopes.
|
|
||||||
)
|
|
||||||
|
|
||||||
ref_output = torch.empty_like(query)
|
|
||||||
ref_single_query_cached_kv_attention(
|
|
||||||
ref_output,
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
block_tables,
|
|
||||||
context_lens,
|
|
||||||
)
|
|
||||||
# NOTE(woosuk): Due to the difference in the data types the two
|
|
||||||
# implementations use for attention softmax logits and accumulation,
|
|
||||||
# there is a small difference in the final outputs.
|
|
||||||
# We should use a relaxed tolerance for the test.
|
|
||||||
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def run_multi_query_kv_attention(
|
|
||||||
num_seqs: int,
|
num_seqs: int,
|
||||||
num_heads: int,
|
num_heads: Tuple[int, int],
|
||||||
head_size: int,
|
head_size: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
random.seed(seed)
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
|
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
|
||||||
num_tokens = sum(seq_lens)
|
num_tokens = sum(seq_lens)
|
||||||
|
|
||||||
scale = float(1.0 / (head_size**0.5))
|
scale = float(1.0 / (head_size**0.5))
|
||||||
|
num_query_heads, num_kv_heads = num_heads
|
||||||
qkv = torch.empty(num_tokens,
|
qkv = torch.empty(num_tokens,
|
||||||
3,
|
num_query_heads + 2 * num_kv_heads,
|
||||||
num_heads,
|
|
||||||
head_size,
|
head_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device='cuda')
|
device="cuda")
|
||||||
qkv.uniform_(-1e-3, 1e-3)
|
qkv.uniform_(-scale, scale)
|
||||||
query, key, value = qkv.unbind(dim=1)
|
query, key, value = qkv.split(
|
||||||
|
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)
|
||||||
|
|
||||||
attn_op = xops.fmha.cutlass.FwOp()
|
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||||
|
if num_queries_per_kv > 1:
|
||||||
|
# Handle MQA and GQA
|
||||||
|
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
|
||||||
|
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
|
||||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
|
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
|
||||||
output = xops.memory_efficient_attention_forward(
|
output = xops.memory_efficient_attention_forward(
|
||||||
query.unsqueeze(0),
|
query.unsqueeze(0),
|
||||||
@ -275,7 +269,6 @@ def run_multi_query_kv_attention(
|
|||||||
attn_bias=attn_bias,
|
attn_bias=attn_bias,
|
||||||
p=0.0,
|
p=0.0,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
op=attn_op,
|
|
||||||
)
|
)
|
||||||
output = output.squeeze(0)
|
output = output.squeeze(0)
|
||||||
|
|
||||||
@ -287,40 +280,7 @@ def run_multi_query_kv_attention(
|
|||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
|
scale,
|
||||||
dtype,
|
dtype,
|
||||||
)
|
)
|
||||||
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
def test_single_query_cached_kv_attention() -> None:
|
|
||||||
torch.random.manual_seed(TEST_SEED)
|
|
||||||
torch.cuda.manual_seed(TEST_SEED)
|
|
||||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
|
||||||
for block_size in [8, 16, 32]:
|
|
||||||
for head_size in [64, 80, 96, 112, 128, 256]:
|
|
||||||
print(f'Testing single_query_cached_kv_attention with '
|
|
||||||
f'dtype={dtype}, block_size={block_size}, '
|
|
||||||
f'head_size={head_size}')
|
|
||||||
run_single_query_cached_kv_attention(
|
|
||||||
num_tokens=37,
|
|
||||||
num_heads=3,
|
|
||||||
head_size=head_size,
|
|
||||||
block_size=block_size,
|
|
||||||
num_blocks=1024,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_multi_query_kv_attention() -> None:
|
|
||||||
torch.random.manual_seed(TEST_SEED)
|
|
||||||
torch.cuda.manual_seed(TEST_SEED)
|
|
||||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
|
||||||
for head_size in [64, 80, 96, 112, 128, 256]:
|
|
||||||
print(f'Testing multi_query_kv_attention with dtype={dtype}, '
|
|
||||||
f'head_size={head_size}')
|
|
||||||
run_multi_query_kv_attention(
|
|
||||||
num_seqs=5,
|
|
||||||
num_heads=3,
|
|
||||||
head_size=head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
@ -1,12 +1,32 @@
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import cache_ops
|
from vllm import cache_ops
|
||||||
|
|
||||||
|
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||||
|
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
|
||||||
|
NUM_LAYERS = [5] # Arbitrary values for testing
|
||||||
|
NUM_HEADS = [8] # Arbitrary values for testing
|
||||||
|
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||||
|
BLOCK_SIZES = [8, 16, 32]
|
||||||
|
NUM_BLOCKS = [1024] # Arbitrary values for testing
|
||||||
|
NUM_MAPPINGS = [32, 256] # Arbitrary values for testing
|
||||||
|
SEEDS = [0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
||||||
|
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
|
||||||
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
|
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def run_copy_blocks(
|
def test_copy_blocks(
|
||||||
|
kv_cache_factory,
|
||||||
num_mappings: int,
|
num_mappings: int,
|
||||||
num_layers: int,
|
num_layers: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
@ -14,48 +34,43 @@ def run_copy_blocks(
|
|||||||
block_size: int,
|
block_size: int,
|
||||||
num_blocks: int,
|
num_blocks: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Generate random block mappings.
|
random.seed(seed)
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
# Generate random block mappings where each source block is mapped to two
|
||||||
|
# destination blocks.
|
||||||
|
assert 2 * num_mappings <= num_blocks
|
||||||
src_blocks = random.sample(range(num_blocks), num_mappings)
|
src_blocks = random.sample(range(num_blocks), num_mappings)
|
||||||
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
|
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
|
||||||
dst_blocks = random.sample(remainig_blocks, num_mappings)
|
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
|
||||||
block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}
|
block_mapping = {}
|
||||||
|
for i in range(num_mappings):
|
||||||
|
src = src_blocks[i]
|
||||||
|
dst1 = dst_blocks[2 * i]
|
||||||
|
dst2 = dst_blocks[2 * i + 1]
|
||||||
|
block_mapping[src] = [dst1, dst2]
|
||||||
|
|
||||||
# Create the KV cache.
|
# Create the KV caches.
|
||||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
|
||||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
num_layers, num_heads,
|
||||||
key_caches = []
|
head_size, dtype, seed)
|
||||||
for _ in range(num_layers):
|
|
||||||
key_cache = torch.randn(size=key_cache_shape,
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
key_caches.append(key_cache)
|
|
||||||
cloned_key_caches = []
|
|
||||||
for key_cache in key_caches:
|
|
||||||
cloned_key_caches.append(key_cache.clone())
|
|
||||||
|
|
||||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
# Clone the KV caches.
|
||||||
value_caches = []
|
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
|
||||||
for _ in range(num_layers):
|
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
|
||||||
value_cache = torch.randn(size=value_cache_shape,
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
value_caches.append(value_cache)
|
|
||||||
cloned_value_caches = []
|
|
||||||
for value_cache in value_caches:
|
|
||||||
cloned_value_caches.append(value_cache.clone())
|
|
||||||
|
|
||||||
# Call the copy blocks kernel.
|
# Call the copy blocks kernel.
|
||||||
cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
|
cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
|
||||||
|
|
||||||
# Reference implementation.
|
# Run the reference implementation.
|
||||||
for src, dsts in block_mapping.items():
|
for src, dsts in block_mapping.items():
|
||||||
for dst in dsts:
|
for dst in dsts:
|
||||||
for key_cache, cloned_key_cache in zip(key_caches,
|
for cloned_key_cache in cloned_key_caches:
|
||||||
cloned_key_caches):
|
|
||||||
cloned_key_cache[dst] = cloned_key_cache[src]
|
cloned_key_cache[dst] = cloned_key_cache[src]
|
||||||
for value_cache, cloned_value_cache in zip(value_caches,
|
for cloned_value_cache in cloned_value_caches:
|
||||||
cloned_value_caches):
|
|
||||||
cloned_value_cache[dst] = cloned_value_cache[src]
|
cloned_value_cache[dst] = cloned_value_cache[src]
|
||||||
|
|
||||||
# Compare the results.
|
# Compare the results.
|
||||||
@ -66,15 +81,29 @@ def run_copy_blocks(
|
|||||||
assert torch.allclose(value_cache, cloned_value_cache)
|
assert torch.allclose(value_cache, cloned_value_cache)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||||
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
|
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||||
|
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def run_reshape_and_cache(
|
def test_reshape_and_cache(
|
||||||
|
kv_cache_factory,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
block_size: int,
|
block_size: int,
|
||||||
num_blocks: int,
|
num_blocks: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
random.seed(seed)
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
# Create a random slot mapping.
|
||||||
num_slots = block_size * num_blocks
|
num_slots = block_size * num_blocks
|
||||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
slot_mapping = random.sample(range(num_slots), num_tokens)
|
||||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
||||||
@ -87,110 +116,31 @@ def run_reshape_and_cache(
|
|||||||
device='cuda')
|
device='cuda')
|
||||||
_, key, value = qkv.unbind(dim=1)
|
_, key, value = qkv.unbind(dim=1)
|
||||||
|
|
||||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
# Create the KV caches.
|
||||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
|
||||||
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
|
num_heads, head_size, dtype,
|
||||||
cloned_key_cache = key_cache.clone()
|
seed)
|
||||||
|
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||||
|
|
||||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
# Clone the KV caches.
|
||||||
value_cache = torch.randn(size=value_cache_shape,
|
cloned_key_cache = key_cache.clone()
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
cloned_value_cache = value_cache.clone()
|
cloned_value_cache = value_cache.clone()
|
||||||
|
|
||||||
|
# Call the reshape_and_cache kernel.
|
||||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
|
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
|
||||||
slot_mapping)
|
slot_mapping)
|
||||||
|
|
||||||
|
# Run the reference implementation.
|
||||||
|
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
|
||||||
|
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
|
||||||
|
block_indicies = block_indicies.cpu().tolist()
|
||||||
|
block_offsets = slot_mapping % block_size
|
||||||
|
block_offsets = block_offsets.cpu().tolist()
|
||||||
for i in range(num_tokens):
|
for i in range(num_tokens):
|
||||||
reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
|
block_idx = block_indicies[i]
|
||||||
block_idx = torch.div(slot_mapping[i],
|
block_offset = block_offsets[i]
|
||||||
block_size,
|
|
||||||
rounding_mode='floor')
|
|
||||||
block_offset = slot_mapping[i] % block_size
|
|
||||||
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
|
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
|
||||||
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
|
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
|
||||||
|
|
||||||
assert torch.allclose(key_cache, cloned_key_cache)
|
assert torch.allclose(key_cache, cloned_key_cache)
|
||||||
assert torch.allclose(value_cache, cloned_value_cache)
|
assert torch.allclose(value_cache, cloned_value_cache)
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def run_gather_cached_kv(
|
|
||||||
num_tokens: int,
|
|
||||||
num_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
block_size: int,
|
|
||||||
num_blocks: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
) -> None:
|
|
||||||
num_slots = block_size * num_blocks
|
|
||||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
|
||||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
|
||||||
|
|
||||||
qkv = torch.randn(num_tokens,
|
|
||||||
3,
|
|
||||||
num_heads,
|
|
||||||
head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
_, key, value = qkv.unbind(dim=1)
|
|
||||||
|
|
||||||
qkv_clone = qkv.clone()
|
|
||||||
_, cloned_key, cloned_value = qkv_clone.unbind(dim=1)
|
|
||||||
|
|
||||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
|
||||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
|
||||||
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
|
|
||||||
|
|
||||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
|
||||||
value_cache = torch.randn(size=value_cache_shape,
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda')
|
|
||||||
|
|
||||||
cache_ops.gather_cached_kv(key, value, key_cache, value_cache,
|
|
||||||
slot_mapping)
|
|
||||||
|
|
||||||
# Reference implementation.
|
|
||||||
for i in range(num_tokens):
|
|
||||||
reshaped_key = cloned_key.reshape(num_tokens, num_heads,
|
|
||||||
head_size // x, x)
|
|
||||||
block_idx = torch.div(slot_mapping[i],
|
|
||||||
block_size,
|
|
||||||
rounding_mode='floor')
|
|
||||||
block_offset = slot_mapping[i] % block_size
|
|
||||||
reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
|
|
||||||
cloned_value[i] = value_cache[block_idx, :, :, block_offset]
|
|
||||||
|
|
||||||
assert torch.allclose(key, cloned_key)
|
|
||||||
assert torch.allclose(value, cloned_value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_copy_blocks() -> None:
|
|
||||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
|
||||||
run_copy_blocks(num_mappings=23,
|
|
||||||
num_layers=7,
|
|
||||||
num_heads=17,
|
|
||||||
head_size=16,
|
|
||||||
block_size=8,
|
|
||||||
num_blocks=1024,
|
|
||||||
dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def test_reshape_and_cache() -> None:
|
|
||||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
|
||||||
run_reshape_and_cache(num_tokens=3,
|
|
||||||
num_heads=2,
|
|
||||||
head_size=16,
|
|
||||||
block_size=8,
|
|
||||||
num_blocks=2,
|
|
||||||
dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def test_gather_cached_kv() -> None:
|
|
||||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
|
||||||
run_gather_cached_kv(num_tokens=3,
|
|
||||||
num_heads=2,
|
|
||||||
head_size=16,
|
|
||||||
block_size=8,
|
|
||||||
num_blocks=2,
|
|
||||||
dtype=dtype)
|
|
||||||
|
@ -1,35 +1,50 @@
|
|||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm import layernorm_ops
|
from vllm import layernorm_ops
|
||||||
|
|
||||||
|
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||||
|
HIDDEN_SIZES = [67, 768, 2048, 5120, 8192] # Arbitrary values for testing
|
||||||
|
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
|
||||||
|
SEEDS = [0]
|
||||||
|
|
||||||
|
|
||||||
class RefRMSNorm(nn.Module):
|
class RefRMSNorm(nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
weight = torch.empty(hidden_size)
|
weight = torch.empty(hidden_size)
|
||||||
weight.uniform_(-1e-3, 1e-3)
|
weight.normal_(mean=1.0, std=0.1)
|
||||||
self.weight = nn.Parameter(weight)
|
self.weight = nn.Parameter(weight)
|
||||||
self.variance_epsilon = eps
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1,
|
input_dtype = hidden_states.dtype
|
||||||
keepdim=True)
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance +
|
hidden_states = hidden_states * torch.rsqrt(variance +
|
||||||
self.variance_epsilon)
|
self.variance_epsilon)
|
||||||
if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
|
return self.weight * hidden_states.to(input_dtype)
|
||||||
hidden_states = hidden_states.to(self.weight.dtype)
|
|
||||||
return self.weight * hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||||
|
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def run_rms_norm(
|
def test_rms_norm(
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
scale = float(hidden_size**-0.5)
|
||||||
|
x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||||
|
x.uniform_(-scale, scale)
|
||||||
ref = RefRMSNorm(hidden_size).to(dtype).cuda()
|
ref = RefRMSNorm(hidden_size).to(dtype).cuda()
|
||||||
|
|
||||||
out = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
@ -40,17 +55,4 @@ def run_rms_norm(
|
|||||||
ref.variance_epsilon,
|
ref.variance_epsilon,
|
||||||
)
|
)
|
||||||
ref_out = ref(x)
|
ref_out = ref(x)
|
||||||
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)
|
assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
def test_rms_norm() -> None:
|
|
||||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
|
||||||
for num_tokens in [7, 128, 2048]:
|
|
||||||
for hidden_size in [13, 64, 1024, 5120]:
|
|
||||||
print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
|
|
||||||
f'{num_tokens}, hidden_size={hidden_size}')
|
|
||||||
run_rms_norm(
|
|
||||||
num_tokens=num_tokens,
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
@ -1,47 +1,70 @@
|
|||||||
from typing import Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from vllm import pos_encoding_ops
|
from vllm import pos_encoding_ops
|
||||||
|
|
||||||
|
IS_NEOX_STYLE = [True, False]
|
||||||
|
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||||
|
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||||
|
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
|
||||||
|
NUM_HEADS = [7, 12, 40, 52] # Arbitrary values for testing
|
||||||
|
NUM_TOKENS = [11, 83, 2048] # Arbitrary values for testing
|
||||||
|
SEEDS = [0]
|
||||||
|
|
||||||
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
|
||||||
|
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
||||||
x1 = x[..., :x.shape[-1] // 2]
|
x1 = x[..., :x.shape[-1] // 2]
|
||||||
x2 = x[..., x.shape[-1] // 2:]
|
x2 = x[..., x.shape[-1] // 2:]
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(
|
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x1 = x[..., ::2]
|
||||||
|
x2 = x[..., 1::2]
|
||||||
|
x = torch.stack((-x2, x1), dim=-1)
|
||||||
|
return x.flatten(-2)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rope(
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
cos: torch.Tensor,
|
cos: torch.Tensor,
|
||||||
sin: torch.Tensor,
|
sin: torch.Tensor,
|
||||||
|
is_neox_style: bool,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
rotate_fn = rotate_neox if is_neox_style else rotate_gptj
|
||||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
q_embed = (q * cos) + (rotate_fn(q) * sin)
|
||||||
|
k_embed = (k * cos) + (rotate_fn(k) * sin)
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
class RefRotaryEmbeddingNeox(nn.Module):
|
class RefRotaryEmbedding(nn.Module):
|
||||||
"""Reference implementation of the GPT-NeoX style rotary embedding."""
|
"""Reference implementation of rotary embedding."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim: int,
|
dim: int,
|
||||||
max_position_embeddings: int = 2048,
|
is_neox_style: bool,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
base: int = 10000,
|
base: int = 10000,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rotary_dim = dim
|
self.rotary_dim = dim
|
||||||
|
self.is_neox_style = is_neox_style
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
# Create cos and sin embeddings.
|
# Create cos and sin embeddings.
|
||||||
inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
|
inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
|
||||||
t = torch.arange(max_position_embeddings).float()
|
t = torch.arange(max_position_embeddings).float()
|
||||||
freqs = torch.einsum("i,j->ij", t, inv_freq.float())
|
freqs = torch.einsum("i,j->ij", t, inv_freq.float())
|
||||||
|
if is_neox_style:
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
else:
|
||||||
|
emb = torch.repeat_interleave(freqs, 2, -1)
|
||||||
cos = emb.cos().to(dtype=inv_freq.dtype)
|
cos = emb.cos().to(dtype=inv_freq.dtype)
|
||||||
sin = emb.sin().to(dtype=inv_freq.dtype)
|
sin = emb.sin().to(dtype=inv_freq.dtype)
|
||||||
self.register_buffer("cos_cached", cos, persistent=False)
|
self.register_buffer("cos_cached", cos, persistent=False)
|
||||||
@ -53,7 +76,6 @@ class RefRotaryEmbeddingNeox(nn.Module):
|
|||||||
query: torch.Tensor, # [num_tokens, num_heads, head_size]
|
query: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||||
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
query_rot = query[..., :self.rotary_dim]
|
query_rot = query[..., :self.rotary_dim]
|
||||||
query_pass = query[..., self.rotary_dim:]
|
query_pass = query[..., self.rotary_dim:]
|
||||||
key_rot = key[..., :self.rotary_dim]
|
key_rot = key[..., :self.rotary_dim]
|
||||||
@ -63,7 +85,9 @@ class RefRotaryEmbeddingNeox(nn.Module):
|
|||||||
key_rot = key_rot.transpose(0, 1)
|
key_rot = key_rot.transpose(0, 1)
|
||||||
cos = F.embedding(positions, self.cos_cached)
|
cos = F.embedding(positions, self.cos_cached)
|
||||||
sin = F.embedding(positions, self.sin_cached)
|
sin = F.embedding(positions, self.sin_cached)
|
||||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
|
|
||||||
|
query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
|
||||||
|
self.is_neox_style)
|
||||||
query_rot = query_rot.transpose(0, 1).contiguous()
|
query_rot = query_rot.transpose(0, 1).contiguous()
|
||||||
key_rot = key_rot.transpose(0, 1).contiguous()
|
key_rot = key_rot.transpose(0, 1).contiguous()
|
||||||
|
|
||||||
@ -74,30 +98,44 @@ class RefRotaryEmbeddingNeox(nn.Module):
|
|||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||||
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||||
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
|
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def run_rotary_embedding_neox(
|
def test_rotary_embedding(
|
||||||
|
is_neox_style: bool,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
max_position: int,
|
rotary_dim: Optional[int],
|
||||||
rotary_dim: int,
|
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
|
max_position: int = 8192,
|
||||||
base: int = 10000,
|
base: int = 10000,
|
||||||
) -> None:
|
) -> None:
|
||||||
positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
|
if rotary_dim is None:
|
||||||
|
rotary_dim = head_size
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
|
||||||
query = torch.randn(num_tokens,
|
query = torch.randn(num_tokens,
|
||||||
num_heads * head_size,
|
num_heads * head_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device='cuda')
|
device="cuda")
|
||||||
key = torch.randn(num_tokens,
|
key = torch.randn(num_tokens,
|
||||||
num_heads * head_size,
|
num_heads * head_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device='cuda')
|
device="cuda")
|
||||||
|
|
||||||
# Create the rotary embedding.
|
# Create the rotary embedding.
|
||||||
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
|
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
|
||||||
t = torch.arange(max_position).float()
|
t = torch.arange(max_position).float()
|
||||||
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
|
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
|
||||||
cos = freqs.cos()
|
cos = freqs.cos()
|
||||||
sin = freqs.sin()
|
sin = freqs.sin()
|
||||||
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
||||||
@ -106,20 +144,22 @@ def run_rotary_embedding_neox(
|
|||||||
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
|
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
|
||||||
out_query = query.clone()
|
out_query = query.clone()
|
||||||
out_key = key.clone()
|
out_key = key.clone()
|
||||||
pos_encoding_ops.rotary_embedding_neox(
|
pos_encoding_ops.rotary_embedding(
|
||||||
positions,
|
positions,
|
||||||
out_query,
|
out_query,
|
||||||
out_key,
|
out_key,
|
||||||
head_size,
|
head_size,
|
||||||
cos_sin_cache,
|
cos_sin_cache,
|
||||||
|
is_neox_style,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run the reference implementation.
|
# Run the reference implementation.
|
||||||
ref_rotary_embedding = RefRotaryEmbeddingNeox(
|
ref_rotary_embedding = RefRotaryEmbedding(
|
||||||
dim=rotary_dim,
|
dim=rotary_dim,
|
||||||
|
is_neox_style=is_neox_style,
|
||||||
max_position_embeddings=max_position,
|
max_position_embeddings=max_position,
|
||||||
base=base,
|
base=base,
|
||||||
).to(dtype=dtype, device='cuda')
|
).to(dtype=dtype, device="cuda")
|
||||||
ref_query, ref_key = ref_rotary_embedding(
|
ref_query, ref_key = ref_rotary_embedding(
|
||||||
positions,
|
positions,
|
||||||
query.view(num_tokens, num_heads, head_size),
|
query.view(num_tokens, num_heads, head_size),
|
||||||
@ -129,19 +169,5 @@ def run_rotary_embedding_neox(
|
|||||||
ref_key = ref_key.view(num_tokens, num_heads * head_size)
|
ref_key = ref_key.view(num_tokens, num_heads * head_size)
|
||||||
|
|
||||||
# Compare the results.
|
# Compare the results.
|
||||||
assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
|
assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
|
||||||
assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
|
assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
def test_rotary_embedding_neox() -> None:
|
|
||||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
|
||||||
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
|
|
||||||
print(f'Running tests for head_size={head_size} and dtype={dtype}')
|
|
||||||
run_rotary_embedding_neox(
|
|
||||||
num_tokens=2145,
|
|
||||||
num_heads=5,
|
|
||||||
head_size=head_size,
|
|
||||||
max_position=8192,
|
|
||||||
rotary_dim=head_size,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
45
tests/models/test_models.py
Normal file
45
tests/models/test_models.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
"""Compare the outputs of HF and vLLM when using greedy sampling.
|
||||||
|
|
||||||
|
Run `pytest tests/models/test_models.py --forked`.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
MODELS = [
|
||||||
|
"facebook/opt-125m",
|
||||||
|
"gpt2",
|
||||||
|
"bigcode/tiny_starcoder_py",
|
||||||
|
"EleutherAI/gpt-j-6b",
|
||||||
|
"EleutherAI/pythia-70m",
|
||||||
|
"bigscience/bloom-560m",
|
||||||
|
"mosaicml/mpt-7b",
|
||||||
|
"tiiuae/falcon-7b",
|
||||||
|
"meta-llama/Llama-2-7b-hf",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [128])
|
||||||
|
def test_models(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
example_prompts,
|
||||||
|
model: str,
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
) -> None:
|
||||||
|
hf_model = hf_runner(model, dtype=dtype)
|
||||||
|
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
del hf_model
|
||||||
|
|
||||||
|
vllm_model = vllm_runner(model, dtype=dtype)
|
||||||
|
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
del vllm_model
|
||||||
|
|
||||||
|
for i in range(len(example_prompts)):
|
||||||
|
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||||
|
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||||
|
assert hf_output_str == vllm_output_str, (
|
||||||
|
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||||
|
assert hf_output_ids == vllm_output_ids, (
|
||||||
|
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
46
tests/samplers/test_beam_search.py
Normal file
46
tests/samplers/test_beam_search.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
"""Compare the outputs of HF and vLLM when using beam search.
|
||||||
|
|
||||||
|
Run `pytest tests/samplers/test_beam_search.py --forked`.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# FIXME(zhuohan): The test can not pass if we:
|
||||||
|
# 1. Increase max_tokens to 256.
|
||||||
|
# 2. Increase beam_width to 8.
|
||||||
|
# 3. Use the model "huggyllama/llama-7b".
|
||||||
|
MAX_TOKENS = [128]
|
||||||
|
BEAM_WIDTHS = [4]
|
||||||
|
MODELS = ["facebook/opt-125m"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
|
||||||
|
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
|
||||||
|
def test_beam_search_single_input(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
example_prompts,
|
||||||
|
model: str,
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
beam_width: int,
|
||||||
|
) -> None:
|
||||||
|
hf_model = hf_runner(model, dtype=dtype)
|
||||||
|
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
|
||||||
|
max_tokens)
|
||||||
|
del hf_model
|
||||||
|
|
||||||
|
vllm_model = vllm_runner(model, dtype=dtype)
|
||||||
|
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
|
||||||
|
max_tokens)
|
||||||
|
del vllm_model
|
||||||
|
|
||||||
|
for i in range(len(example_prompts)):
|
||||||
|
hf_output_ids, _ = hf_outputs[i]
|
||||||
|
vllm_output_ids, _ = vllm_outputs[i]
|
||||||
|
assert len(hf_output_ids) == len(vllm_output_ids)
|
||||||
|
for j in range(len(hf_output_ids)):
|
||||||
|
assert hf_output_ids[j] == vllm_output_ids[j], (
|
||||||
|
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
|
||||||
|
f"vLLM: {vllm_output_ids}")
|
@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
|
|||||||
from vllm.outputs import CompletionOutput, RequestOutput
|
from vllm.outputs import CompletionOutput, RequestOutput
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
|
|
||||||
__version__ = "0.1.3"
|
__version__ = "0.1.5"
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LLM",
|
"LLM",
|
||||||
|
@ -24,9 +24,16 @@ class ModelConfig:
|
|||||||
downloading the model and tokenizer.
|
downloading the model and tokenizer.
|
||||||
download_dir: Directory to download and load the weights, default to the
|
download_dir: Directory to download and load the weights, default to the
|
||||||
default cache directory of huggingface.
|
default cache directory of huggingface.
|
||||||
use_np_weights: Save a numpy copy of model weights for faster loading.
|
load_format: The format of the model weights to load:
|
||||||
This can increase the disk usage by up to 2x.
|
"auto" will try to load the weights in the safetensors format and
|
||||||
use_dummy_weights: Use dummy values for model weights (for profiling).
|
fall back to the pytorch bin format if safetensors format is
|
||||||
|
not available.
|
||||||
|
"pt" will load the weights in the pytorch bin format.
|
||||||
|
"safetensors" will load the weights in the safetensors format.
|
||||||
|
"npcache" will load the weights in pytorch format and store
|
||||||
|
a numpy cache to speed up the loading.
|
||||||
|
"dummy" will initialize the weights with random values, which is
|
||||||
|
mainly for profiling.
|
||||||
dtype: Data type for model weights and activations. The "auto" option
|
dtype: Data type for model weights and activations. The "auto" option
|
||||||
will use FP16 precision for FP32 and FP16 models, and BF16 precision
|
will use FP16 precision for FP32 and FP16 models, and BF16 precision
|
||||||
for BF16 models.
|
for BF16 models.
|
||||||
@ -40,8 +47,7 @@ class ModelConfig:
|
|||||||
tokenizer_mode: str,
|
tokenizer_mode: str,
|
||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
download_dir: Optional[str],
|
download_dir: Optional[str],
|
||||||
use_np_weights: bool,
|
load_format: str,
|
||||||
use_dummy_weights: bool,
|
|
||||||
dtype: str,
|
dtype: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -50,14 +56,24 @@ class ModelConfig:
|
|||||||
self.tokenizer_mode = tokenizer_mode
|
self.tokenizer_mode = tokenizer_mode
|
||||||
self.trust_remote_code = trust_remote_code
|
self.trust_remote_code = trust_remote_code
|
||||||
self.download_dir = download_dir
|
self.download_dir = download_dir
|
||||||
self.use_np_weights = use_np_weights
|
self.load_format = load_format
|
||||||
self.use_dummy_weights = use_dummy_weights
|
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
|
||||||
self.hf_config = get_config(model, trust_remote_code)
|
self.hf_config = get_config(model, trust_remote_code)
|
||||||
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
||||||
|
self._verify_load_format()
|
||||||
self._verify_tokenizer_mode()
|
self._verify_tokenizer_mode()
|
||||||
|
|
||||||
|
def _verify_load_format(self) -> None:
|
||||||
|
load_format = self.load_format.lower()
|
||||||
|
if load_format not in [
|
||||||
|
"auto", "pt", "safetensors", "npcache", "dummy"
|
||||||
|
]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown load format: {self.load_format}. Must be one of "
|
||||||
|
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
|
||||||
|
self.load_format = load_format
|
||||||
|
|
||||||
def _verify_tokenizer_mode(self) -> None:
|
def _verify_tokenizer_mode(self) -> None:
|
||||||
tokenizer_mode = self.tokenizer_mode.lower()
|
tokenizer_mode = self.tokenizer_mode.lower()
|
||||||
if tokenizer_mode not in ["auto", "slow"]:
|
if tokenizer_mode not in ["auto", "slow"]:
|
||||||
@ -98,9 +114,11 @@ class ModelConfig:
|
|||||||
# Note: for falcon, when new_decoder_architecture is True, the
|
# Note: for falcon, when new_decoder_architecture is True, the
|
||||||
# multi_query flag is ignored and we use n_head_kv for the number of
|
# multi_query flag is ignored and we use n_head_kv for the number of
|
||||||
# KV heads.
|
# KV heads.
|
||||||
if (getattr(self.hf_config, "multi_query", False) and
|
new_decoder_arch_falcon = (
|
||||||
(self.hf_config.model_type == "falcon" and
|
self.hf_config.model_type == "falcon"
|
||||||
not getattr(self.hf_config, "new_decoder_architecture", False))):
|
and getattr(self.hf_config, "new_decoder_architecture", False))
|
||||||
|
if not new_decoder_arch_falcon and getattr(self.hf_config,
|
||||||
|
"multi_query", False):
|
||||||
# Multi-query attention, only one KV head.
|
# Multi-query attention, only one KV head.
|
||||||
return 1
|
return 1
|
||||||
# For Falcon:
|
# For Falcon:
|
||||||
|
@ -172,9 +172,7 @@ class BlockSpaceManager:
|
|||||||
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
|
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
|
||||||
# CPU block -> GPU block.
|
# CPU block -> GPU block.
|
||||||
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
|
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
|
||||||
for seq in seq_group.get_seqs():
|
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
|
||||||
if seq.is_finished():
|
|
||||||
continue
|
|
||||||
new_block_table: BlockTable = []
|
new_block_table: BlockTable = []
|
||||||
block_table = self.block_tables[seq.seq_id]
|
block_table = self.block_tables[seq.seq_id]
|
||||||
|
|
||||||
@ -203,9 +201,7 @@ class BlockSpaceManager:
|
|||||||
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
|
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
|
||||||
# GPU block -> CPU block.
|
# GPU block -> CPU block.
|
||||||
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
|
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
|
||||||
for seq in seq_group.get_seqs():
|
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||||
if seq.is_finished():
|
|
||||||
continue
|
|
||||||
new_block_table: BlockTable = []
|
new_block_table: BlockTable = []
|
||||||
block_table = self.block_tables[seq.seq_id]
|
block_table = self.block_tables[seq.seq_id]
|
||||||
|
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
import enum
|
import enum
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from vllm.config import CacheConfig, SchedulerConfig
|
from vllm.config import CacheConfig, SchedulerConfig
|
||||||
from vllm.core.block_manager import BlockSpaceManager
|
from vllm.core.block_manager import BlockSpaceManager
|
||||||
from vllm.core.policy import PolicyFactory
|
from vllm.core.policy import PolicyFactory
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
|
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
|
||||||
SequenceGroupMetadata, SequenceOutputs,
|
SequenceGroupMetadata, SequenceStatus)
|
||||||
SequenceStatus)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -64,6 +63,9 @@ class Scheduler:
|
|||||||
self.scheduler_config = scheduler_config
|
self.scheduler_config = scheduler_config
|
||||||
self.cache_config = cache_config
|
self.cache_config = cache_config
|
||||||
|
|
||||||
|
self.prompt_limit = min(self.scheduler_config.max_model_len,
|
||||||
|
self.scheduler_config.max_num_batched_tokens)
|
||||||
|
|
||||||
# Instantiate the scheduling policy.
|
# Instantiate the scheduling policy.
|
||||||
self.policy = PolicyFactory.get_policy(policy_name="fcfs")
|
self.policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||||
# Create the block space manager.
|
# Create the block space manager.
|
||||||
@ -73,6 +75,7 @@ class Scheduler:
|
|||||||
num_cpu_blocks=self.cache_config.num_cpu_blocks,
|
num_cpu_blocks=self.cache_config.num_cpu_blocks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO(zhuohan): Use deque instead of list for better performance.
|
||||||
# Sequence groups in the WAITING state.
|
# Sequence groups in the WAITING state.
|
||||||
self.waiting: List[SequenceGroup] = []
|
self.waiting: List[SequenceGroup] = []
|
||||||
# Sequence groups in the RUNNING state.
|
# Sequence groups in the RUNNING state.
|
||||||
@ -84,16 +87,25 @@ class Scheduler:
|
|||||||
# Add sequence groups to the waiting queue.
|
# Add sequence groups to the waiting queue.
|
||||||
self.waiting.append(seq_group)
|
self.waiting.append(seq_group)
|
||||||
|
|
||||||
def abort_seq_group(self, request_id: str) -> None:
|
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||||
|
if isinstance(request_id, str):
|
||||||
|
request_id = (request_id, )
|
||||||
|
request_ids = set(request_id)
|
||||||
for state_queue in [self.waiting, self.running, self.swapped]:
|
for state_queue in [self.waiting, self.running, self.swapped]:
|
||||||
for seq_group in state_queue:
|
# We need to reverse the list as we are removing elements
|
||||||
if seq_group.request_id == request_id:
|
# from it as we iterate over it. If we don't do it,
|
||||||
|
# indices will get messed up and we will skip over elements.
|
||||||
|
for seq_group in reversed(state_queue):
|
||||||
|
if seq_group.request_id in request_ids:
|
||||||
# Remove the sequence group from the state queue.
|
# Remove the sequence group from the state queue.
|
||||||
state_queue.remove(seq_group)
|
state_queue.remove(seq_group)
|
||||||
for seq in seq_group.seqs:
|
for seq in seq_group.get_seqs():
|
||||||
if seq.is_finished():
|
if seq.is_finished():
|
||||||
continue
|
continue
|
||||||
self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
|
seq.status = SequenceStatus.FINISHED_ABORTED
|
||||||
|
self.free_seq(seq)
|
||||||
|
request_ids.remove(seq_group.request_id)
|
||||||
|
if not request_ids:
|
||||||
return
|
return
|
||||||
|
|
||||||
def has_unfinished_seqs(self) -> bool:
|
def has_unfinished_seqs(self) -> bool:
|
||||||
@ -115,6 +127,10 @@ class Scheduler:
|
|||||||
if not self.swapped:
|
if not self.swapped:
|
||||||
ignored_seq_groups: List[SequenceGroup] = []
|
ignored_seq_groups: List[SequenceGroup] = []
|
||||||
scheduled: List[SequenceGroup] = []
|
scheduled: List[SequenceGroup] = []
|
||||||
|
# The total number of sequences on the fly, including the
|
||||||
|
# requests in the generation phase.
|
||||||
|
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
|
||||||
|
for seq_group in self.running)
|
||||||
num_batched_tokens = 0
|
num_batched_tokens = 0
|
||||||
# Optimization: We do not sort the waiting queue since the preempted
|
# Optimization: We do not sort the waiting queue since the preempted
|
||||||
# sequence groups are added to the front and the new sequence groups
|
# sequence groups are added to the front and the new sequence groups
|
||||||
@ -122,19 +138,19 @@ class Scheduler:
|
|||||||
while self.waiting:
|
while self.waiting:
|
||||||
seq_group = self.waiting[0]
|
seq_group = self.waiting[0]
|
||||||
|
|
||||||
|
assert seq_group.num_seqs() == 1, (
|
||||||
|
"Waiting sequence group should have only one prompt "
|
||||||
|
"sequence.")
|
||||||
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
|
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
|
||||||
prompt_limit = min(
|
if num_prompt_tokens > self.prompt_limit:
|
||||||
self.scheduler_config.max_model_len,
|
|
||||||
self.scheduler_config.max_num_batched_tokens)
|
|
||||||
if num_prompt_tokens > prompt_limit:
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Input prompt ({num_prompt_tokens} tokens) is too long"
|
f"Input prompt ({num_prompt_tokens} tokens) is too long"
|
||||||
f" and exceeds limit of {prompt_limit}")
|
f" and exceeds limit of {self.prompt_limit}")
|
||||||
for seq in seq_group.get_seqs():
|
for seq in seq_group.get_seqs():
|
||||||
seq.status = SequenceStatus.FINISHED_IGNORED
|
seq.status = SequenceStatus.FINISHED_IGNORED
|
||||||
ignored_seq_groups.append(seq_group)
|
ignored_seq_groups.append(seq_group)
|
||||||
self.waiting.pop(0)
|
self.waiting.pop(0)
|
||||||
break
|
continue
|
||||||
|
|
||||||
# If the sequence group cannot be allocated, stop.
|
# If the sequence group cannot be allocated, stop.
|
||||||
if not self.block_manager.can_allocate(seq_group):
|
if not self.block_manager.can_allocate(seq_group):
|
||||||
@ -147,11 +163,7 @@ class Scheduler:
|
|||||||
|
|
||||||
# The total number of sequences in the RUNNING state should not
|
# The total number of sequences in the RUNNING state should not
|
||||||
# exceed the maximum number of sequences.
|
# exceed the maximum number of sequences.
|
||||||
num_new_seqs = seq_group.num_seqs(
|
num_new_seqs = seq_group.get_max_num_running_seqs()
|
||||||
status=SequenceStatus.WAITING)
|
|
||||||
num_curr_seqs = sum(
|
|
||||||
seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
|
||||||
for seq_group in self.running)
|
|
||||||
if (num_curr_seqs + num_new_seqs >
|
if (num_curr_seqs + num_new_seqs >
|
||||||
self.scheduler_config.max_num_seqs):
|
self.scheduler_config.max_num_seqs):
|
||||||
break
|
break
|
||||||
@ -160,6 +172,7 @@ class Scheduler:
|
|||||||
self._allocate(seq_group)
|
self._allocate(seq_group)
|
||||||
self.running.append(seq_group)
|
self.running.append(seq_group)
|
||||||
num_batched_tokens += num_prompt_tokens
|
num_batched_tokens += num_prompt_tokens
|
||||||
|
num_curr_seqs += num_new_seqs
|
||||||
scheduled.append(seq_group)
|
scheduled.append(seq_group)
|
||||||
|
|
||||||
if scheduled:
|
if scheduled:
|
||||||
@ -205,21 +218,19 @@ class Scheduler:
|
|||||||
|
|
||||||
# Swap in the sequence groups in the SWAPPED state if possible.
|
# Swap in the sequence groups in the SWAPPED state if possible.
|
||||||
self.swapped = self.policy.sort_by_priority(now, self.swapped)
|
self.swapped = self.policy.sort_by_priority(now, self.swapped)
|
||||||
while self.swapped and not blocks_to_swap_out:
|
if not preempted:
|
||||||
|
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
|
||||||
|
for seq_group in self.running)
|
||||||
|
|
||||||
|
while self.swapped:
|
||||||
seq_group = self.swapped[0]
|
seq_group = self.swapped[0]
|
||||||
# If the sequence group has been preempted in this step, stop.
|
|
||||||
if seq_group in preempted:
|
|
||||||
break
|
|
||||||
# If the sequence group cannot be swapped in, stop.
|
# If the sequence group cannot be swapped in, stop.
|
||||||
if not self.block_manager.can_swap_in(seq_group):
|
if not self.block_manager.can_swap_in(seq_group):
|
||||||
break
|
break
|
||||||
|
|
||||||
# The total number of sequences in the RUNNING state should not
|
# The total number of sequences in the RUNNING state should not
|
||||||
# exceed the maximum number of sequences.
|
# exceed the maximum number of sequences.
|
||||||
num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
|
num_new_seqs = seq_group.get_max_num_running_seqs()
|
||||||
num_curr_seqs = sum(
|
|
||||||
seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
|
||||||
for seq_group in self.running)
|
|
||||||
if (num_curr_seqs + num_new_seqs >
|
if (num_curr_seqs + num_new_seqs >
|
||||||
self.scheduler_config.max_num_seqs):
|
self.scheduler_config.max_num_seqs):
|
||||||
break
|
break
|
||||||
@ -227,8 +238,12 @@ class Scheduler:
|
|||||||
seq_group = self.swapped.pop(0)
|
seq_group = self.swapped.pop(0)
|
||||||
self._swap_in(seq_group, blocks_to_swap_in)
|
self._swap_in(seq_group, blocks_to_swap_in)
|
||||||
self._append_slot(seq_group, blocks_to_copy)
|
self._append_slot(seq_group, blocks_to_copy)
|
||||||
|
num_curr_seqs += num_new_seqs
|
||||||
self.running.append(seq_group)
|
self.running.append(seq_group)
|
||||||
|
|
||||||
|
# Each sequence in the generation phase only takes one token slot.
|
||||||
|
# Therefore, the number of batched tokens is equal to the number of
|
||||||
|
# sequences in the RUNNING state.
|
||||||
num_batched_tokens = sum(
|
num_batched_tokens = sum(
|
||||||
seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
||||||
for seq_group in self.running)
|
for seq_group in self.running)
|
||||||
@ -270,40 +285,10 @@ class Scheduler:
|
|||||||
seq_group_metadata_list.append(seq_group_metadata)
|
seq_group_metadata_list.append(seq_group_metadata)
|
||||||
return seq_group_metadata_list, scheduler_outputs
|
return seq_group_metadata_list, scheduler_outputs
|
||||||
|
|
||||||
def update(
|
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
||||||
self,
|
self.block_manager.fork(parent_seq, child_seq)
|
||||||
seq_outputs: Dict[int, SequenceOutputs],
|
|
||||||
) -> List[SequenceGroup]:
|
|
||||||
scheduled: List[SequenceGroup] = []
|
|
||||||
for seq_group in self.running:
|
|
||||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
|
||||||
if seq.seq_id in seq_outputs:
|
|
||||||
scheduled.append(seq_group)
|
|
||||||
break
|
|
||||||
|
|
||||||
# Update the scheduled sequences and free blocks.
|
def free_seq(self, seq: Sequence) -> None:
|
||||||
for seq_group in scheduled:
|
|
||||||
# Process beam search results before processing the new tokens.
|
|
||||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
|
||||||
output = seq_outputs[seq.seq_id]
|
|
||||||
if seq.seq_id != output.parent_seq_id:
|
|
||||||
# The sequence is a fork of the parent sequence (beam
|
|
||||||
# search). Free the current sequence.
|
|
||||||
self.block_manager.free(seq)
|
|
||||||
# Fork the parent sequence.
|
|
||||||
parent_seq = seq_group.find(output.parent_seq_id)
|
|
||||||
parent_seq.fork(seq)
|
|
||||||
self.block_manager.fork(parent_seq, seq)
|
|
||||||
|
|
||||||
# Process the new tokens.
|
|
||||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
|
||||||
# Append a new token to the sequence.
|
|
||||||
output = seq_outputs[seq.seq_id]
|
|
||||||
seq.append_token_id(output.output_token, output.logprobs)
|
|
||||||
return scheduled
|
|
||||||
|
|
||||||
def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
|
|
||||||
seq.status = finish_status
|
|
||||||
self.block_manager.free(seq)
|
self.block_manager.free(seq)
|
||||||
|
|
||||||
def free_finished_seq_groups(self) -> None:
|
def free_finished_seq_groups(self) -> None:
|
||||||
@ -340,8 +325,8 @@ class Scheduler:
|
|||||||
# If preemption mode is not specified, we determine the mode as follows:
|
# If preemption mode is not specified, we determine the mode as follows:
|
||||||
# We use recomputation by default since it incurs lower overhead than
|
# We use recomputation by default since it incurs lower overhead than
|
||||||
# swapping. However, when the sequence group has multiple sequences
|
# swapping. However, when the sequence group has multiple sequences
|
||||||
# (e.g., beam search), recomputation is not supported. In such a case,
|
# (e.g., beam search), recomputation is not currently supported. In
|
||||||
# we use swapping instead.
|
# such a case, we use swapping instead.
|
||||||
# FIXME(woosuk): This makes our scheduling policy a bit bizarre.
|
# FIXME(woosuk): This makes our scheduling policy a bit bizarre.
|
||||||
# As swapped sequences are prioritized over waiting sequences,
|
# As swapped sequences are prioritized over waiting sequences,
|
||||||
# sequence groups with multiple sequences are implicitly prioritized
|
# sequence groups with multiple sequences are implicitly prioritized
|
||||||
@ -349,8 +334,7 @@ class Scheduler:
|
|||||||
# TODO(woosuk): Support recomputation for sequence groups with multiple
|
# TODO(woosuk): Support recomputation for sequence groups with multiple
|
||||||
# sequences. This may require a more sophisticated CUDA kernel.
|
# sequences. This may require a more sophisticated CUDA kernel.
|
||||||
if preemption_mode is None:
|
if preemption_mode is None:
|
||||||
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
if seq_group.get_max_num_running_seqs() == 1:
|
||||||
if len(seqs) == 1:
|
|
||||||
preemption_mode = PreemptionMode.RECOMPUTE
|
preemption_mode = PreemptionMode.RECOMPUTE
|
||||||
else:
|
else:
|
||||||
preemption_mode = PreemptionMode.SWAP
|
preemption_mode = PreemptionMode.SWAP
|
||||||
@ -379,9 +363,6 @@ class Scheduler:
|
|||||||
seq_group: SequenceGroup,
|
seq_group: SequenceGroup,
|
||||||
blocks_to_swap_out: Dict[int, int],
|
blocks_to_swap_out: Dict[int, int],
|
||||||
) -> None:
|
) -> None:
|
||||||
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
|
||||||
for seq in seqs:
|
|
||||||
seq.status = SequenceStatus.SWAPPED
|
|
||||||
self._swap_out(seq_group, blocks_to_swap_out)
|
self._swap_out(seq_group, blocks_to_swap_out)
|
||||||
self.swapped.append(seq_group)
|
self.swapped.append(seq_group)
|
||||||
|
|
||||||
|
@ -15,8 +15,7 @@ class EngineArgs:
|
|||||||
tokenizer_mode: str = 'auto'
|
tokenizer_mode: str = 'auto'
|
||||||
trust_remote_code: bool = False
|
trust_remote_code: bool = False
|
||||||
download_dir: Optional[str] = None
|
download_dir: Optional[str] = None
|
||||||
use_np_weights: bool = False
|
load_format: str = 'auto'
|
||||||
use_dummy_weights: bool = False
|
|
||||||
dtype: str = 'auto'
|
dtype: str = 'auto'
|
||||||
seed: int = 0
|
seed: int = 0
|
||||||
worker_use_ray: bool = False
|
worker_use_ray: bool = False
|
||||||
@ -65,14 +64,21 @@ class EngineArgs:
|
|||||||
help='directory to download and load the weights, '
|
help='directory to download and load the weights, '
|
||||||
'default to the default cache dir of '
|
'default to the default cache dir of '
|
||||||
'huggingface')
|
'huggingface')
|
||||||
parser.add_argument('--use-np-weights',
|
parser.add_argument(
|
||||||
action='store_true',
|
'--load-format',
|
||||||
help='save a numpy copy of model weights for '
|
type=str,
|
||||||
'faster loading. This can increase the disk '
|
default=EngineArgs.load_format,
|
||||||
'usage by up to 2x.')
|
choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
|
||||||
parser.add_argument('--use-dummy-weights',
|
help='The format of the model weights to load. '
|
||||||
action='store_true',
|
'"auto" will try to load the weights in the safetensors format '
|
||||||
help='use dummy values for model weights')
|
'and fall back to the pytorch bin format if safetensors format '
|
||||||
|
'is not available. '
|
||||||
|
'"pt" will load the weights in the pytorch bin format. '
|
||||||
|
'"safetensors" will load the weights in the safetensors format. '
|
||||||
|
'"npcache" will load the weights in pytorch format and store '
|
||||||
|
'a numpy cache to speed up the loading. '
|
||||||
|
'"dummy" will initialize the weights with random values, '
|
||||||
|
'which is mainly for profiling.')
|
||||||
# TODO(woosuk): Support FP32.
|
# TODO(woosuk): Support FP32.
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--dtype',
|
'--dtype',
|
||||||
@ -146,9 +152,8 @@ class EngineArgs:
|
|||||||
# Initialize the configs.
|
# Initialize the configs.
|
||||||
model_config = ModelConfig(self.model, self.tokenizer,
|
model_config = ModelConfig(self.model, self.tokenizer,
|
||||||
self.tokenizer_mode, self.trust_remote_code,
|
self.tokenizer_mode, self.trust_remote_code,
|
||||||
self.download_dir, self.use_np_weights,
|
self.download_dir, self.load_format,
|
||||||
self.use_dummy_weights, self.dtype,
|
self.dtype, self.seed)
|
||||||
self.seed)
|
|
||||||
cache_config = CacheConfig(self.block_size,
|
cache_config = CacheConfig(self.block_size,
|
||||||
self.gpu_memory_utilization,
|
self.gpu_memory_utilization,
|
||||||
self.swap_space)
|
self.swap_space)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional
|
from functools import partial
|
||||||
|
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
@ -12,7 +13,202 @@ from vllm.sampling_params import SamplingParams
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
|
|
||||||
|
class AsyncEngineDeadError(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_exception_on_finish(task: asyncio.Task,
|
||||||
|
request_tracker: "RequestTracker") -> None:
|
||||||
|
msg = ("Task finished unexpectedly. This should never happen! "
|
||||||
|
"Please open an issue on Github.")
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
task.result()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return
|
||||||
|
except Exception as exc:
|
||||||
|
raise AsyncEngineDeadError(
|
||||||
|
msg + " See stack trace above for the actual cause.") from exc
|
||||||
|
raise AsyncEngineDeadError(msg)
|
||||||
|
except Exception as exc:
|
||||||
|
request_tracker.propagate_exception(exc)
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncStream:
|
||||||
|
"""A stream of RequestOutputs for a request that can be
|
||||||
|
iterated over asynchronously."""
|
||||||
|
|
||||||
|
def __init__(self, request_id: str) -> None:
|
||||||
|
self.request_id = request_id
|
||||||
|
self._queue = asyncio.Queue()
|
||||||
|
self._finished = False
|
||||||
|
|
||||||
|
def put(self, item: RequestOutput) -> None:
|
||||||
|
if self._finished:
|
||||||
|
return
|
||||||
|
self._queue.put_nowait(item)
|
||||||
|
|
||||||
|
def finish(self) -> None:
|
||||||
|
self._queue.put_nowait(StopIteration)
|
||||||
|
self._finished = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def finished(self) -> bool:
|
||||||
|
return self._finished
|
||||||
|
|
||||||
|
def __aiter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self) -> RequestOutput:
|
||||||
|
result = await self._queue.get()
|
||||||
|
if result is StopIteration:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
elif isinstance(result, Exception):
|
||||||
|
raise result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class RequestTracker:
|
||||||
|
"""Synchronous abstraction for tracking requests."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._request_streams: Dict[str, AsyncStream] = {}
|
||||||
|
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
|
||||||
|
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
|
||||||
|
dict]] = asyncio.Queue()
|
||||||
|
|
||||||
|
def __contains__(self, item):
|
||||||
|
return item in self._request_streams
|
||||||
|
|
||||||
|
def propagate_exception(self, exc: Exception) -> None:
|
||||||
|
"""Propagate an exception to all request streams."""
|
||||||
|
for stream in self._request_streams.values():
|
||||||
|
stream.put(exc)
|
||||||
|
|
||||||
|
def process_request_output(self,
|
||||||
|
request_output: RequestOutput,
|
||||||
|
*,
|
||||||
|
verbose: bool = False) -> None:
|
||||||
|
"""Process a request output from the engine."""
|
||||||
|
request_id = request_output.request_id
|
||||||
|
|
||||||
|
self._request_streams[request_id].put(request_output)
|
||||||
|
if request_output.finished:
|
||||||
|
if verbose:
|
||||||
|
logger.info(f"Finished request {request_id}.")
|
||||||
|
self.abort_request(request_id)
|
||||||
|
|
||||||
|
def add_request(self, request_id: str,
|
||||||
|
**engine_add_request_kwargs) -> AsyncStream:
|
||||||
|
"""Add a request to be sent to the engine on the next background
|
||||||
|
loop iteration."""
|
||||||
|
if request_id in self._request_streams:
|
||||||
|
raise KeyError(f"Request {request_id} already exists.")
|
||||||
|
|
||||||
|
stream = AsyncStream(request_id)
|
||||||
|
self._new_requests.put_nowait((stream, {
|
||||||
|
"request_id": request_id,
|
||||||
|
**engine_add_request_kwargs
|
||||||
|
}))
|
||||||
|
return stream
|
||||||
|
|
||||||
|
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
|
||||||
|
"""Abort a request during next background loop iteration."""
|
||||||
|
if verbose:
|
||||||
|
logger.info(f"Aborted request {request_id}.")
|
||||||
|
|
||||||
|
self._finished_requests.put_nowait(request_id)
|
||||||
|
|
||||||
|
if request_id not in self._request_streams or self._request_streams[
|
||||||
|
request_id].finished:
|
||||||
|
# The request has already finished or been aborted.
|
||||||
|
return
|
||||||
|
|
||||||
|
self._request_streams[request_id].finish()
|
||||||
|
|
||||||
|
def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
|
||||||
|
"""Get the new requests and finished requests to be
|
||||||
|
sent to the engine."""
|
||||||
|
new_requests: List[dict] = []
|
||||||
|
finished_requests: Set[str] = set()
|
||||||
|
|
||||||
|
while not self._finished_requests.empty():
|
||||||
|
request_id = self._finished_requests.get_nowait()
|
||||||
|
finished_requests.add(request_id)
|
||||||
|
self._request_streams.pop(request_id, None)
|
||||||
|
|
||||||
|
while not self._new_requests.empty():
|
||||||
|
stream, new_request = self._new_requests.get_nowait()
|
||||||
|
if stream.request_id in finished_requests:
|
||||||
|
# The request has already been aborted.
|
||||||
|
stream.finish()
|
||||||
|
continue
|
||||||
|
self._request_streams[stream.request_id] = stream
|
||||||
|
new_requests.append(new_request)
|
||||||
|
|
||||||
|
return new_requests, finished_requests
|
||||||
|
|
||||||
|
|
||||||
|
class _AsyncLLMEngine(LLMEngine):
|
||||||
|
"""Extension of LLMEngine to add async methods."""
|
||||||
|
|
||||||
|
async def step_async(self) -> List[RequestOutput]:
|
||||||
|
"""Performs one decoding iteration and returns newly generated results.
|
||||||
|
The workers are ran asynchronously if possible.
|
||||||
|
|
||||||
|
This function performs one decoding iteration of the engine. It first
|
||||||
|
schedules the sequences to be executed in the next iteration and the
|
||||||
|
token blocks to be swapped in/out/copy. Then, it executes the model
|
||||||
|
and updates the scheduler with the model outputs. Finally, it decodes
|
||||||
|
the sequences and returns the newly generated results.
|
||||||
|
"""
|
||||||
|
(seq_group_metadata_list, scheduler_outputs,
|
||||||
|
early_return) = self._schedule()
|
||||||
|
if early_return is not None:
|
||||||
|
return early_return
|
||||||
|
|
||||||
|
# Execute the model.
|
||||||
|
output = await self._run_workers_async(
|
||||||
|
"execute_model",
|
||||||
|
seq_group_metadata_list=seq_group_metadata_list,
|
||||||
|
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
|
||||||
|
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
||||||
|
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._process_model_outputs(output, scheduler_outputs)
|
||||||
|
|
||||||
|
async def _run_workers_async(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
*args,
|
||||||
|
get_all_outputs: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Any:
|
||||||
|
"""Runs the given method on all workers."""
|
||||||
|
all_outputs = []
|
||||||
|
for worker in self.workers:
|
||||||
|
if self.parallel_config.worker_use_ray:
|
||||||
|
executor = partial(worker.execute_method.remote, method)
|
||||||
|
else:
|
||||||
|
executor = getattr(worker, method)
|
||||||
|
|
||||||
|
output = executor(*args, **kwargs)
|
||||||
|
all_outputs.append(output)
|
||||||
|
|
||||||
|
if self.parallel_config.worker_use_ray:
|
||||||
|
all_outputs = await asyncio.gather(*all_outputs)
|
||||||
|
|
||||||
|
if get_all_outputs:
|
||||||
|
return all_outputs
|
||||||
|
|
||||||
|
# Make sure all workers have the same results.
|
||||||
|
output = all_outputs[0]
|
||||||
|
for other_output in all_outputs[1:]:
|
||||||
|
assert output == other_output
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class AsyncLLMEngine:
|
class AsyncLLMEngine:
|
||||||
@ -37,49 +233,117 @@ class AsyncLLMEngine:
|
|||||||
*args, *kwargs: Arguments for LLMEngine.
|
*args, *kwargs: Arguments for LLMEngine.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
worker_use_ray: bool,
|
worker_use_ray: bool,
|
||||||
engine_use_ray: bool,
|
engine_use_ray: bool,
|
||||||
*args,
|
*args,
|
||||||
log_requests: bool = True,
|
log_requests: bool = True,
|
||||||
|
start_engine_loop: bool = False,
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
self.worker_use_ray = worker_use_ray
|
self.worker_use_ray = worker_use_ray
|
||||||
self.engine_use_ray = engine_use_ray
|
self.engine_use_ray = engine_use_ray
|
||||||
self.log_requests = log_requests
|
self.log_requests = log_requests
|
||||||
if not self.engine_use_ray:
|
self.engine = self._init_engine(*args, **kwargs)
|
||||||
engine_class = LLMEngine
|
|
||||||
elif self.worker_use_ray:
|
|
||||||
engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
|
|
||||||
else:
|
|
||||||
engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
|
|
||||||
self.engine = engine_class(*args, **kwargs)
|
|
||||||
# Request id -> request output.
|
|
||||||
self.request_outputs: Dict[str, RequestOutput] = {}
|
|
||||||
# Request id -> event to notify that there is new output.
|
|
||||||
self.request_events: Dict[str, asyncio.Event] = {}
|
|
||||||
self.is_engine_running = False
|
|
||||||
self.kicking_request_id: Optional[str] = None
|
|
||||||
|
|
||||||
async def engine_step(self, kicking_request_id: Optional[str] = None):
|
self.request_tracker: RequestTracker = RequestTracker()
|
||||||
|
self.background_loop = None
|
||||||
|
if start_engine_loop:
|
||||||
|
self.start_background_loop()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_running(self) -> bool:
|
||||||
|
return (self.background_loop is not None
|
||||||
|
and not self.background_loop.done())
|
||||||
|
|
||||||
|
def start_background_loop(self) -> None:
|
||||||
|
"""Start the background loop."""
|
||||||
|
if self.is_running:
|
||||||
|
raise RuntimeError("Background loop is already running.")
|
||||||
|
self.background_loop = asyncio.get_event_loop().create_task(
|
||||||
|
self.run_engine_loop())
|
||||||
|
self.background_loop.add_done_callback(
|
||||||
|
partial(_raise_exception_on_finish,
|
||||||
|
request_tracker=self.request_tracker))
|
||||||
|
|
||||||
|
def _init_engine(self, *args,
|
||||||
|
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
|
||||||
|
if not self.engine_use_ray:
|
||||||
|
engine_class = self._engine_class
|
||||||
|
elif self.worker_use_ray:
|
||||||
|
engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
|
||||||
|
else:
|
||||||
|
engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
|
||||||
|
return engine_class(*args, **kwargs)
|
||||||
|
|
||||||
|
async def engine_step(self):
|
||||||
"""Kick the engine to process the waiting requests."""
|
"""Kick the engine to process the waiting requests."""
|
||||||
self.is_engine_running = True
|
|
||||||
self.kicking_request_id = kicking_request_id
|
new_requests, finished_requests = (
|
||||||
|
self.request_tracker.get_new_and_finished_requests())
|
||||||
|
|
||||||
|
for new_request in new_requests:
|
||||||
|
# Add the request into the vLLM engine's waiting queue.
|
||||||
|
# TODO: Maybe add add_request_batch to reduce Ray overhead
|
||||||
|
if self.engine_use_ray:
|
||||||
|
await self.engine.add_request.remote(**new_request)
|
||||||
|
else:
|
||||||
|
self.engine.add_request(**new_request)
|
||||||
|
|
||||||
|
if finished_requests:
|
||||||
|
await self._engine_abort(finished_requests)
|
||||||
|
|
||||||
if self.engine_use_ray:
|
if self.engine_use_ray:
|
||||||
request_outputs = await self.engine.step.remote()
|
request_outputs = await self.engine.step.remote()
|
||||||
else:
|
else:
|
||||||
# Yield to the event loop to allow other coroutines to run
|
request_outputs = await self.engine.step_async()
|
||||||
# while is_engine_running is True. This let the engine to add new
|
|
||||||
# requests into the queue.
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
request_outputs = self.engine.step()
|
|
||||||
self.is_engine_running = False
|
|
||||||
self.kicking_request_id = None
|
|
||||||
|
|
||||||
# Notify the waiting coroutines that there are new outputs ready.
|
# Put the outputs into the corresponding streams.
|
||||||
for request_output in request_outputs:
|
for request_output in request_outputs:
|
||||||
request_id = request_output.request_id
|
self.request_tracker.process_request_output(
|
||||||
self.request_outputs[request_id] = request_output
|
request_output, verbose=self.log_requests)
|
||||||
self.request_events[request_id].set()
|
|
||||||
|
async def _engine_abort(self, request_ids: Iterable[str]):
|
||||||
|
if self.engine_use_ray:
|
||||||
|
await self.engine.abort_request.remote(request_ids)
|
||||||
|
else:
|
||||||
|
self.engine.abort_request(request_ids)
|
||||||
|
|
||||||
|
async def run_engine_loop(self):
|
||||||
|
while True:
|
||||||
|
await self.engine_step()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
async def add_request(
|
||||||
|
self,
|
||||||
|
request_id: str,
|
||||||
|
prompt: Optional[str],
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
prompt_token_ids: Optional[List[int]] = None,
|
||||||
|
arrival_time: Optional[float] = None,
|
||||||
|
) -> AsyncStream:
|
||||||
|
if self.log_requests:
|
||||||
|
logger.info(f"Received request {request_id}: "
|
||||||
|
f"prompt: {prompt!r}, "
|
||||||
|
f"sampling params: {sampling_params}, "
|
||||||
|
f"prompt token ids: {prompt_token_ids}.")
|
||||||
|
|
||||||
|
if not self.is_running:
|
||||||
|
raise AsyncEngineDeadError(
|
||||||
|
"Background loop is not running. If it was running, "
|
||||||
|
"inspect the output to find the stacktrace of the "
|
||||||
|
"error that caused the background loop to stop "
|
||||||
|
"(AsyncEngineDeadError).")
|
||||||
|
|
||||||
|
stream = self.request_tracker.add_request(
|
||||||
|
request_id,
|
||||||
|
prompt=prompt,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
prompt_token_ids=prompt_token_ids,
|
||||||
|
arrival_time=arrival_time)
|
||||||
|
|
||||||
|
return stream
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
@ -108,76 +372,19 @@ class AsyncLLMEngine:
|
|||||||
# Preprocess the request.
|
# Preprocess the request.
|
||||||
arrival_time = time.time()
|
arrival_time = time.time()
|
||||||
|
|
||||||
# Create an event to notify us that there is new output from the
|
try:
|
||||||
# vLLM engine.
|
stream = await self.add_request(request_id,
|
||||||
request_event = asyncio.Event()
|
|
||||||
self.request_events[request_id] = request_event
|
|
||||||
|
|
||||||
if self.log_requests:
|
|
||||||
logger.info(f"Received request {request_id}: "
|
|
||||||
f"prompt: {prompt!r}, "
|
|
||||||
f"sampling params: {sampling_params}, "
|
|
||||||
f"prompt token ids: {prompt_token_ids}.")
|
|
||||||
|
|
||||||
# Add the request into the vLLM engine's waiting queue.
|
|
||||||
if self.engine_use_ray:
|
|
||||||
await self.engine.add_request.remote(
|
|
||||||
request_id,
|
|
||||||
prompt,
|
|
||||||
sampling_params,
|
|
||||||
prompt_token_ids=prompt_token_ids,
|
|
||||||
arrival_time=arrival_time)
|
|
||||||
else:
|
|
||||||
self.engine.add_request(request_id,
|
|
||||||
prompt,
|
prompt,
|
||||||
sampling_params,
|
sampling_params,
|
||||||
prompt_token_ids=prompt_token_ids,
|
prompt_token_ids=prompt_token_ids,
|
||||||
arrival_time=arrival_time)
|
arrival_time=arrival_time)
|
||||||
|
|
||||||
# The vLLM engine does not have a background loop that keeps
|
async for request_output in stream:
|
||||||
# processing incoming requests. Therefore, we need to keep kicking
|
|
||||||
# the engine to process the requests.
|
|
||||||
while True:
|
|
||||||
if request_id not in self.request_events:
|
|
||||||
# The request has been aborted.
|
|
||||||
return
|
|
||||||
|
|
||||||
# Kick the engine if the engine is not running.
|
|
||||||
if not self.is_engine_running:
|
|
||||||
try:
|
|
||||||
await self.engine_step(request_id)
|
|
||||||
except RuntimeError as e:
|
|
||||||
await self.abort(request_id)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
# Wait for new output. The group_event will be set in engine_step
|
|
||||||
# when there is new output available for the sequence group.
|
|
||||||
# Added a timeout to prevent deadlock.
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(request_event.wait(),
|
|
||||||
timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
continue
|
|
||||||
# Reset the event to wait for the next output.
|
|
||||||
request_event.clear()
|
|
||||||
|
|
||||||
# Decode and return new outputs.
|
|
||||||
request_output = self.request_outputs[request_id]
|
|
||||||
yield request_output
|
yield request_output
|
||||||
|
except Exception as e:
|
||||||
# Once finished, release the resources of the sequence group.
|
# If there is an exception, abort the request.
|
||||||
if request_output.finished:
|
self._abort(request_id)
|
||||||
if self.log_requests:
|
raise e
|
||||||
logger.info(f"Finished request {request_id}.")
|
|
||||||
|
|
||||||
del self.request_outputs[request_id]
|
|
||||||
del self.request_events[request_id]
|
|
||||||
# Kick the engine if the engine is not running. This is to
|
|
||||||
# prevent that there are still requests in engine's waiting
|
|
||||||
# queue to be executed.
|
|
||||||
if not self.is_engine_running:
|
|
||||||
await self.engine_step()
|
|
||||||
break
|
|
||||||
|
|
||||||
async def abort(self, request_id: str) -> None:
|
async def abort(self, request_id: str) -> None:
|
||||||
"""Abort a request.
|
"""Abort a request.
|
||||||
@ -188,28 +395,26 @@ class AsyncLLMEngine:
|
|||||||
Args:
|
Args:
|
||||||
request_id: The unique id of the request.
|
request_id: The unique id of the request.
|
||||||
"""
|
"""
|
||||||
if request_id not in self.request_events:
|
if not self.is_running:
|
||||||
# The request has already finished or been aborted.
|
raise AsyncEngineDeadError(
|
||||||
return
|
"Background loop is not running. If it was running, "
|
||||||
|
"inspect the output to find the stacktrace of the "
|
||||||
|
"error that caused the background loop to stop "
|
||||||
|
"(AsyncEngineDeadError).")
|
||||||
|
|
||||||
if self.log_requests:
|
return self._abort(request_id)
|
||||||
logger.info(f"Aborted request {request_id}.")
|
|
||||||
|
|
||||||
if self.engine_use_ray:
|
def _abort(self, request_id: str) -> None:
|
||||||
await self.engine.abort_request.remote(request_id)
|
"""Abort a request.
|
||||||
else:
|
|
||||||
self.engine.abort_request(request_id)
|
|
||||||
|
|
||||||
if request_id in self.request_events:
|
Abort a submitted request. If the request is finished or not found,
|
||||||
del self.request_events[request_id]
|
this method will be a no-op.
|
||||||
if request_id in self.request_outputs:
|
|
||||||
del self.request_outputs[request_id]
|
|
||||||
|
|
||||||
# To prevent deadlock when a request is aborted while the engine is
|
Args:
|
||||||
# running.
|
request_id: The unique id of the request.
|
||||||
if self.kicking_request_id == request_id:
|
"""
|
||||||
self.is_engine_running = False
|
self.request_tracker.abort_request(request_id,
|
||||||
self.kicking_request_id = None
|
verbose=self.log_requests)
|
||||||
|
|
||||||
async def get_model_config(self) -> ModelConfig:
|
async def get_model_config(self) -> ModelConfig:
|
||||||
"""Get the model configuration of the vLLM engine."""
|
"""Get the model configuration of the vLLM engine."""
|
||||||
@ -220,7 +425,8 @@ class AsyncLLMEngine:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_engine_args(cls,
|
def from_engine_args(cls,
|
||||||
engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
|
engine_args: AsyncEngineArgs,
|
||||||
|
start_engine_loop: bool = False) -> "AsyncLLMEngine":
|
||||||
"""Creates an async LLM engine from the engine arguments."""
|
"""Creates an async LLM engine from the engine arguments."""
|
||||||
# Create the engine configs.
|
# Create the engine configs.
|
||||||
engine_configs = engine_args.create_engine_configs()
|
engine_configs = engine_args.create_engine_configs()
|
||||||
@ -235,5 +441,6 @@ class AsyncLLMEngine:
|
|||||||
distributed_init_method,
|
distributed_init_method,
|
||||||
placement_group,
|
placement_group,
|
||||||
log_requests=not engine_args.disable_log_requests,
|
log_requests=not engine_args.disable_log_requests,
|
||||||
log_stats=not engine_args.disable_log_stats)
|
log_stats=not engine_args.disable_log_stats,
|
||||||
|
start_engine_loop=start_engine_loop)
|
||||||
return engine
|
return engine
|
||||||
|
@ -1,17 +1,19 @@
|
|||||||
import time
|
|
||||||
import copy
|
import copy
|
||||||
|
import time
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||||
SchedulerConfig)
|
SchedulerConfig)
|
||||||
from vllm.core.scheduler import Scheduler
|
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker
|
from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
|
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
|
||||||
|
SequenceGroupMetadata, SequenceOutputs,
|
||||||
|
SequenceStatus)
|
||||||
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
||||||
get_tokenizer)
|
get_tokenizer)
|
||||||
from vllm.utils import Counter
|
from vllm.utils import Counter
|
||||||
@ -74,9 +76,8 @@ class LLMEngine:
|
|||||||
f"tokenizer_mode={model_config.tokenizer_mode}, "
|
f"tokenizer_mode={model_config.tokenizer_mode}, "
|
||||||
f"trust_remote_code={model_config.trust_remote_code}, "
|
f"trust_remote_code={model_config.trust_remote_code}, "
|
||||||
f"dtype={model_config.dtype}, "
|
f"dtype={model_config.dtype}, "
|
||||||
f"use_dummy_weights={model_config.use_dummy_weights}, "
|
|
||||||
f"download_dir={model_config.download_dir!r}, "
|
f"download_dir={model_config.download_dir!r}, "
|
||||||
f"use_np_weights={model_config.use_np_weights}, "
|
f"load_format={model_config.load_format}, "
|
||||||
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
||||||
f"seed={model_config.seed})")
|
f"seed={model_config.seed})")
|
||||||
# TODO(woosuk): Print more configs in debug mode.
|
# TODO(woosuk): Print more configs in debug mode.
|
||||||
@ -135,7 +136,8 @@ class LLMEngine:
|
|||||||
get_all_outputs=True,
|
get_all_outputs=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_workers_ray(self, placement_group: "PlacementGroup"):
|
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||||
|
**ray_remote_kwargs):
|
||||||
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||||
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||||
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
|
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
|
||||||
@ -150,6 +152,7 @@ class LLMEngine:
|
|||||||
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||||
placement_group=placement_group,
|
placement_group=placement_group,
|
||||||
placement_group_capture_child_tasks=True),
|
placement_group_capture_child_tasks=True),
|
||||||
|
**ray_remote_kwargs,
|
||||||
)(RayWorker).remote()
|
)(RayWorker).remote()
|
||||||
self.workers.append(worker)
|
self.workers.append(worker)
|
||||||
|
|
||||||
@ -255,24 +258,21 @@ class LLMEngine:
|
|||||||
|
|
||||||
# Create the sequences.
|
# Create the sequences.
|
||||||
block_size = self.cache_config.block_size
|
block_size = self.cache_config.block_size
|
||||||
seqs: List[Sequence] = []
|
|
||||||
for _ in range(sampling_params.best_of):
|
|
||||||
seq_id = next(self.seq_counter)
|
seq_id = next(self.seq_counter)
|
||||||
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
|
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
|
||||||
seqs.append(seq)
|
|
||||||
|
|
||||||
# Create the sequence group.
|
# Create the sequence group.
|
||||||
seq_group = SequenceGroup(request_id, seqs, sampling_params,
|
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
||||||
arrival_time)
|
arrival_time)
|
||||||
|
|
||||||
# Add the sequence group to the scheduler.
|
# Add the sequence group to the scheduler.
|
||||||
self.scheduler.add_seq_group(seq_group)
|
self.scheduler.add_seq_group(seq_group)
|
||||||
|
|
||||||
def abort_request(self, request_id: str) -> None:
|
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||||
"""Aborts a request with the given ID.
|
"""Aborts a request(s) with the given ID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request_id: The ID of the request to abort.
|
request_id: The ID(s) of the request to abort.
|
||||||
"""
|
"""
|
||||||
self.scheduler.abort_seq_group(request_id)
|
self.scheduler.abort_seq_group(request_id)
|
||||||
|
|
||||||
@ -288,6 +288,251 @@ class LLMEngine:
|
|||||||
"""Returns True if there are unfinished requests."""
|
"""Returns True if there are unfinished requests."""
|
||||||
return self.scheduler.has_unfinished_seqs()
|
return self.scheduler.has_unfinished_seqs()
|
||||||
|
|
||||||
|
def _schedule(
|
||||||
|
self
|
||||||
|
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
|
||||||
|
Optional[List[RequestOutput]]]:
|
||||||
|
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||||
|
if scheduler_outputs.is_empty():
|
||||||
|
return seq_group_metadata_list, scheduler_outputs, [
|
||||||
|
RequestOutput.from_seq_group(seq_group)
|
||||||
|
for seq_group in scheduler_outputs.ignored_seq_groups
|
||||||
|
]
|
||||||
|
return seq_group_metadata_list, scheduler_outputs, None
|
||||||
|
|
||||||
|
def _check_beam_search_early_stopping(
|
||||||
|
self,
|
||||||
|
early_stopping: Union[bool, str],
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
best_running_seq: Sequence,
|
||||||
|
current_worst_seq: Sequence,
|
||||||
|
) -> bool:
|
||||||
|
assert sampling_params.use_beam_search
|
||||||
|
length_penalty = sampling_params.length_penalty
|
||||||
|
if early_stopping is True:
|
||||||
|
return True
|
||||||
|
|
||||||
|
current_worst_score = (current_worst_seq.get_beam_search_score(
|
||||||
|
length_penalty=length_penalty,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id))
|
||||||
|
if early_stopping is False:
|
||||||
|
highest_attainable_score = (best_running_seq.get_beam_search_score(
|
||||||
|
length_penalty=length_penalty,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id))
|
||||||
|
else:
|
||||||
|
assert early_stopping == "never"
|
||||||
|
if length_penalty > 0.0:
|
||||||
|
# If length_penalty > 0.0, beam search will prefer longer
|
||||||
|
# sequences. The highest attainable score calculation is
|
||||||
|
# based on the longest possible sequence length in this case.
|
||||||
|
max_possible_length = max(
|
||||||
|
best_running_seq.get_prompt_len() +
|
||||||
|
sampling_params.max_tokens,
|
||||||
|
self.scheduler_config.max_model_len)
|
||||||
|
highest_attainable_score = (
|
||||||
|
best_running_seq.get_beam_search_score(
|
||||||
|
length_penalty=length_penalty,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id,
|
||||||
|
seq_len=max_possible_length))
|
||||||
|
else:
|
||||||
|
# Otherwise, beam search will prefer shorter sequences. The
|
||||||
|
# highest attainable score calculation is based on the current
|
||||||
|
# sequence length.
|
||||||
|
highest_attainable_score = (
|
||||||
|
best_running_seq.get_beam_search_score(
|
||||||
|
length_penalty=length_penalty,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id))
|
||||||
|
return current_worst_score >= highest_attainable_score
|
||||||
|
|
||||||
|
def _process_sequence_group_samples(
|
||||||
|
self, seq_group: SequenceGroup,
|
||||||
|
samples: List[SequenceOutputs]) -> None:
|
||||||
|
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
||||||
|
existing_finished_seqs = seq_group.get_finished_seqs()
|
||||||
|
parent_child_dict = {
|
||||||
|
parent_seq.seq_id: []
|
||||||
|
for parent_seq in parent_seqs
|
||||||
|
}
|
||||||
|
for sample in samples:
|
||||||
|
parent_child_dict[sample.parent_seq_id].append(sample)
|
||||||
|
# List of (child, parent)
|
||||||
|
child_seqs: List[Tuple[Sequence, Sequence]] = []
|
||||||
|
|
||||||
|
# Process the child samples for each parent sequence
|
||||||
|
for parent in parent_seqs:
|
||||||
|
child_samples: List[SequenceOutputs] = parent_child_dict[
|
||||||
|
parent.seq_id]
|
||||||
|
if len(child_samples) == 0:
|
||||||
|
# This parent sequence has no children samples. Remove
|
||||||
|
# the parent sequence from the sequence group since it will
|
||||||
|
# not be used in the future iterations.
|
||||||
|
parent.status = SequenceStatus.FINISHED_ABORTED
|
||||||
|
seq_group.remove(parent.seq_id)
|
||||||
|
self.scheduler.free_seq(parent)
|
||||||
|
continue
|
||||||
|
# Fork the parent sequence if there are multiple child samples.
|
||||||
|
for child_sample in child_samples[:-1]:
|
||||||
|
new_child_seq_id = next(self.seq_counter)
|
||||||
|
child = parent.fork(new_child_seq_id)
|
||||||
|
child.append_token_id(child_sample.output_token,
|
||||||
|
child_sample.logprobs)
|
||||||
|
child_seqs.append((child, parent))
|
||||||
|
# Continue the parent sequence for the last child sample.
|
||||||
|
# We reuse the parent sequence here to reduce redundant memory
|
||||||
|
# copies, especially when using non-beam search sampling methods.
|
||||||
|
last_child_sample = child_samples[-1]
|
||||||
|
parent.append_token_id(last_child_sample.output_token,
|
||||||
|
last_child_sample.logprobs)
|
||||||
|
child_seqs.append((parent, parent))
|
||||||
|
|
||||||
|
for seq, _ in child_seqs:
|
||||||
|
self._decode_sequence(seq)
|
||||||
|
self._check_stop(seq, seq_group.sampling_params)
|
||||||
|
|
||||||
|
# Non-beam search case
|
||||||
|
if not seq_group.sampling_params.use_beam_search:
|
||||||
|
# For newly created child sequences, add them to the sequence group
|
||||||
|
# and fork them in block manager if they are not finished.
|
||||||
|
for seq, parent in child_seqs:
|
||||||
|
if seq is not parent:
|
||||||
|
seq_group.add(seq)
|
||||||
|
if not seq.is_finished():
|
||||||
|
self.scheduler.fork_seq(parent, seq)
|
||||||
|
|
||||||
|
# Free the finished and selected parent sequences' memory in block
|
||||||
|
# manager. Keep them in the sequence group as candidate output.
|
||||||
|
# NOTE: we need to fork the new sequences before freeing the
|
||||||
|
# old sequences.
|
||||||
|
for seq, parent in child_seqs:
|
||||||
|
if seq is parent and seq.is_finished():
|
||||||
|
self.scheduler.free_seq(seq)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Beam search case
|
||||||
|
# Select the child sequences to keep in the sequence group.
|
||||||
|
selected_child_seqs = []
|
||||||
|
unselected_child_seqs = []
|
||||||
|
beam_width = seq_group.sampling_params.best_of
|
||||||
|
length_penalty = seq_group.sampling_params.length_penalty
|
||||||
|
|
||||||
|
# Select the newly finished sequences with the highest scores
|
||||||
|
# to replace existing finished sequences.
|
||||||
|
# Tuple of (seq, parent, is_new)
|
||||||
|
existing_finished_seqs = [(seq, None, False)
|
||||||
|
for seq in existing_finished_seqs]
|
||||||
|
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
|
||||||
|
if seq.is_finished()]
|
||||||
|
all_finished_seqs = existing_finished_seqs + new_finished_seqs
|
||||||
|
# Sort the finished sequences by their scores.
|
||||||
|
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
||||||
|
length_penalty=length_penalty,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id),
|
||||||
|
reverse=True)
|
||||||
|
for seq, parent, is_new in all_finished_seqs[:beam_width]:
|
||||||
|
if is_new:
|
||||||
|
# A newly generated child sequence finishes and has a high
|
||||||
|
# score, so we will add it into the sequence group.
|
||||||
|
selected_child_seqs.append((seq, parent))
|
||||||
|
for seq, parent, is_new in all_finished_seqs[beam_width:]:
|
||||||
|
if is_new:
|
||||||
|
# A newly generated child sequence finishes but has a low
|
||||||
|
# score, so we will not add it into the sequence group.
|
||||||
|
# Additionally, if this sequence is a continuation of a
|
||||||
|
# parent sequence, we will need remove the parent sequence
|
||||||
|
# from the sequence group.
|
||||||
|
unselected_child_seqs.append((seq, parent))
|
||||||
|
else:
|
||||||
|
# An existing finished sequence has a low score, so we will
|
||||||
|
# remove it from the sequence group.
|
||||||
|
seq_group.remove(seq.seq_id)
|
||||||
|
|
||||||
|
# select the top beam_width sequences from the running
|
||||||
|
# sequences for the next iteration to continue the beam
|
||||||
|
# search.
|
||||||
|
running_child_seqs = [(seq, parent) for seq, parent in child_seqs
|
||||||
|
if not seq.is_finished()]
|
||||||
|
# Sort the running sequences by their scores.
|
||||||
|
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
||||||
|
length_penalty=length_penalty,
|
||||||
|
eos_token_id=self.tokenizer.eos_token_id),
|
||||||
|
reverse=True)
|
||||||
|
|
||||||
|
# Check if we can stop the beam search.
|
||||||
|
if len(running_child_seqs) == 0:
|
||||||
|
# No running sequences, stop the beam search.
|
||||||
|
stop_beam_search = True
|
||||||
|
elif len(all_finished_seqs) < beam_width:
|
||||||
|
# Not enough finished sequences, continue the beam search.
|
||||||
|
stop_beam_search = False
|
||||||
|
else:
|
||||||
|
# Check the early stopping criteria
|
||||||
|
best_running_seq = running_child_seqs[0][0]
|
||||||
|
current_worst_seq = all_finished_seqs[beam_width - 1][0]
|
||||||
|
stop_beam_search = self._check_beam_search_early_stopping(
|
||||||
|
seq_group.sampling_params.early_stopping,
|
||||||
|
seq_group.sampling_params, best_running_seq, current_worst_seq)
|
||||||
|
|
||||||
|
if stop_beam_search:
|
||||||
|
# Stop the beam search and remove all the running sequences from
|
||||||
|
# the sequence group.
|
||||||
|
unselected_child_seqs.extend(running_child_seqs)
|
||||||
|
else:
|
||||||
|
# Continue the beam search and select the top beam_width sequences
|
||||||
|
# to continue the beam search.
|
||||||
|
selected_child_seqs.extend(running_child_seqs[:beam_width])
|
||||||
|
# The remaining running sequences will not be used in the next
|
||||||
|
# iteration. Again, if these sequences are continuations of
|
||||||
|
# parent sequences, we will need to remove the parent sequences
|
||||||
|
# from the sequence group.
|
||||||
|
unselected_child_seqs.extend(running_child_seqs[beam_width:])
|
||||||
|
|
||||||
|
# For newly created child sequences, add them to the sequence group
|
||||||
|
# and fork them in block manager if they are not finished.
|
||||||
|
for seq, parent in selected_child_seqs:
|
||||||
|
if seq is not parent:
|
||||||
|
seq_group.add(seq)
|
||||||
|
if not seq.is_finished():
|
||||||
|
self.scheduler.fork_seq(parent, seq)
|
||||||
|
|
||||||
|
# Free the finished and selected parent sequences' memory in block
|
||||||
|
# manager. Keep them in the sequence group as candidate output.
|
||||||
|
for seq, parent in selected_child_seqs:
|
||||||
|
if seq is parent and seq.is_finished():
|
||||||
|
self.scheduler.free_seq(seq)
|
||||||
|
|
||||||
|
# Remove the unselected parent sequences from the sequence group and
|
||||||
|
# free their memory in block manager.
|
||||||
|
for seq, parent in unselected_child_seqs:
|
||||||
|
if seq is parent:
|
||||||
|
# Remove the parent sequence if it is not selected for next
|
||||||
|
# iteration
|
||||||
|
seq_group.remove(seq.seq_id)
|
||||||
|
self.scheduler.free_seq(seq)
|
||||||
|
|
||||||
|
def _process_model_outputs(
|
||||||
|
self, output: SamplerOutput,
|
||||||
|
scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
|
||||||
|
# Update the scheduled sequence groups with the model outputs.
|
||||||
|
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
|
||||||
|
for seq_group, samples in zip(scheduled_seq_groups, output):
|
||||||
|
self._process_sequence_group_samples(seq_group, samples)
|
||||||
|
|
||||||
|
# Free the finished sequence groups.
|
||||||
|
self.scheduler.free_finished_seq_groups()
|
||||||
|
|
||||||
|
# Create the outputs.
|
||||||
|
request_outputs: List[RequestOutput] = []
|
||||||
|
for seq_group in (scheduled_seq_groups +
|
||||||
|
scheduler_outputs.ignored_seq_groups):
|
||||||
|
request_output = RequestOutput.from_seq_group(seq_group)
|
||||||
|
request_outputs.append(request_output)
|
||||||
|
|
||||||
|
if self.log_stats:
|
||||||
|
# Log the system stats.
|
||||||
|
self._log_system_stats(scheduler_outputs.prompt_run,
|
||||||
|
scheduler_outputs.num_batched_tokens)
|
||||||
|
return request_outputs
|
||||||
|
|
||||||
def step(self) -> List[RequestOutput]:
|
def step(self) -> List[RequestOutput]:
|
||||||
"""Performs one decoding iteration and returns newly generated results.
|
"""Performs one decoding iteration and returns newly generated results.
|
||||||
|
|
||||||
@ -297,17 +542,10 @@ class LLMEngine:
|
|||||||
and updates the scheduler with the model outputs. Finally, it decodes
|
and updates the scheduler with the model outputs. Finally, it decodes
|
||||||
the sequences and returns the newly generated results.
|
the sequences and returns the newly generated results.
|
||||||
"""
|
"""
|
||||||
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
(seq_group_metadata_list, scheduler_outputs,
|
||||||
if scheduler_outputs.is_empty():
|
early_return) = self._schedule()
|
||||||
if not scheduler_outputs.ignored_seq_groups:
|
if early_return is not None:
|
||||||
# Nothing to do.
|
return early_return
|
||||||
return []
|
|
||||||
# If there are ignored seq groups, we need to return them as the
|
|
||||||
# request outputs.
|
|
||||||
return [
|
|
||||||
RequestOutput.from_seq_group(seq_group)
|
|
||||||
for seq_group in scheduler_outputs.ignored_seq_groups
|
|
||||||
]
|
|
||||||
|
|
||||||
# Execute the model.
|
# Execute the model.
|
||||||
output = self._run_workers(
|
output = self._run_workers(
|
||||||
@ -317,27 +555,8 @@ class LLMEngine:
|
|||||||
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
||||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||||
)
|
)
|
||||||
# Update the scheduler with the model outputs.
|
|
||||||
seq_groups = self.scheduler.update(output)
|
|
||||||
|
|
||||||
# Decode the sequences.
|
return self._process_model_outputs(output, scheduler_outputs)
|
||||||
self._decode_sequences(seq_groups)
|
|
||||||
# Stop the sequences that meet the stopping criteria.
|
|
||||||
self._stop_sequences(seq_groups)
|
|
||||||
# Free the finished sequence groups.
|
|
||||||
self.scheduler.free_finished_seq_groups()
|
|
||||||
|
|
||||||
# Create the outputs.
|
|
||||||
request_outputs: List[RequestOutput] = []
|
|
||||||
for seq_group in seq_groups + scheduler_outputs.ignored_seq_groups:
|
|
||||||
request_output = RequestOutput.from_seq_group(seq_group)
|
|
||||||
request_outputs.append(request_output)
|
|
||||||
|
|
||||||
if self.log_stats:
|
|
||||||
# Log the system stats.
|
|
||||||
self._log_system_stats(scheduler_outputs.prompt_run,
|
|
||||||
scheduler_outputs.num_batched_tokens)
|
|
||||||
return request_outputs
|
|
||||||
|
|
||||||
def _log_system_stats(
|
def _log_system_stats(
|
||||||
self,
|
self,
|
||||||
@ -402,10 +621,8 @@ class LLMEngine:
|
|||||||
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
|
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
|
||||||
self.last_logging_time = now
|
self.last_logging_time = now
|
||||||
|
|
||||||
def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
|
def _decode_sequence(self, seq: Sequence) -> None:
|
||||||
"""Decodes the sequence outputs."""
|
"""Decodes the new token for a sequence."""
|
||||||
for seq_group in seq_groups:
|
|
||||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
|
||||||
new_token, new_output_text = detokenize_incrementally(
|
new_token, new_output_text = detokenize_incrementally(
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
seq.output_tokens,
|
seq.output_tokens,
|
||||||
@ -416,41 +633,32 @@ class LLMEngine:
|
|||||||
seq.output_tokens.append(new_token)
|
seq.output_tokens.append(new_token)
|
||||||
seq.output_text = new_output_text
|
seq.output_text = new_output_text
|
||||||
|
|
||||||
def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
|
def _check_stop(self, seq: Sequence,
|
||||||
|
sampling_params: SamplingParams) -> None:
|
||||||
"""Stop the finished sequences."""
|
"""Stop the finished sequences."""
|
||||||
for seq_group in seq_groups:
|
|
||||||
sampling_params = seq_group.sampling_params
|
|
||||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
|
||||||
# Check if the sequence has generated a stop string.
|
|
||||||
stopped = False
|
|
||||||
for stop_str in sampling_params.stop:
|
for stop_str in sampling_params.stop:
|
||||||
if seq.output_text.endswith(stop_str):
|
if seq.output_text.endswith(stop_str):
|
||||||
# Truncate the output text so that the stop string is
|
# Truncate the output text so that the stop string is
|
||||||
# not included in the output.
|
# not included in the output.
|
||||||
seq.output_text = seq.output_text[:-len(stop_str)]
|
seq.output_text = seq.output_text[:-len(stop_str)]
|
||||||
self.scheduler.free_seq(
|
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||||
seq, SequenceStatus.FINISHED_STOPPED)
|
return
|
||||||
stopped = True
|
|
||||||
break
|
|
||||||
if stopped:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if the sequence has reached max_model_len.
|
# Check if the sequence has reached max_model_len.
|
||||||
if seq.get_len() > self.scheduler_config.max_model_len:
|
if seq.get_len() > self.scheduler_config.max_model_len:
|
||||||
self.scheduler.free_seq(
|
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||||
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
|
return
|
||||||
continue
|
|
||||||
# Check if the sequence has reached max_tokens.
|
# Check if the sequence has reached max_tokens.
|
||||||
if seq.get_output_len() == sampling_params.max_tokens:
|
if seq.get_output_len() == sampling_params.max_tokens:
|
||||||
self.scheduler.free_seq(
|
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||||
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
|
return
|
||||||
continue
|
|
||||||
# Check if the sequence has generated the EOS token.
|
# Check if the sequence has generated the EOS token.
|
||||||
if not sampling_params.ignore_eos:
|
if ((not sampling_params.ignore_eos)
|
||||||
if seq.get_last_token_id() == self.tokenizer.eos_token_id:
|
and seq.get_last_token_id() == self.tokenizer.eos_token_id):
|
||||||
self.scheduler.free_seq(
|
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||||
seq, SequenceStatus.FINISHED_STOPPED)
|
return
|
||||||
continue
|
|
||||||
|
|
||||||
def _run_workers(
|
def _run_workers(
|
||||||
self,
|
self,
|
||||||
|
@ -14,6 +14,7 @@ from vllm.utils import random_uuid
|
|||||||
TIMEOUT_KEEP_ALIVE = 5 # seconds.
|
TIMEOUT_KEEP_ALIVE = 5 # seconds.
|
||||||
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
|
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
engine = None
|
||||||
|
|
||||||
|
|
||||||
@app.post("/generate")
|
@app.post("/generate")
|
||||||
@ -30,6 +31,10 @@ async def generate(request: Request) -> Response:
|
|||||||
stream = request_dict.pop("stream", False)
|
stream = request_dict.pop("stream", False)
|
||||||
sampling_params = SamplingParams(**request_dict)
|
sampling_params = SamplingParams(**request_dict)
|
||||||
request_id = random_uuid()
|
request_id = random_uuid()
|
||||||
|
|
||||||
|
if not engine.is_running:
|
||||||
|
engine.start_background_loop()
|
||||||
|
|
||||||
results_generator = engine.generate(prompt, sampling_params, request_id)
|
results_generator = engine.generate(prompt, sampling_params, request_id)
|
||||||
|
|
||||||
# Streaming case
|
# Streaming case
|
||||||
@ -75,7 +80,8 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
engine = AsyncLLMEngine.from_engine_args(engine_args,
|
||||||
|
start_engine_loop=False)
|
||||||
|
|
||||||
uvicorn.run(app,
|
uvicorn.run(app,
|
||||||
host=args.host,
|
host=args.host,
|
||||||
|
@ -3,18 +3,18 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
from http import HTTPStatus
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import AsyncGenerator, Dict, List, Optional
|
from http import HTTPStatus
|
||||||
from packaging import version
|
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
|
import uvicorn
|
||||||
from fastapi import BackgroundTasks, Request
|
from fastapi import BackgroundTasks, Request
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
import uvicorn
|
from packaging import version
|
||||||
|
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
@ -44,6 +44,7 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
served_model = None
|
served_model = None
|
||||||
app = fastapi.FastAPI()
|
app = fastapi.FastAPI()
|
||||||
|
engine = None
|
||||||
|
|
||||||
|
|
||||||
def create_error_response(status_code: HTTPStatus,
|
def create_error_response(status_code: HTTPStatus,
|
||||||
@ -115,12 +116,22 @@ async def get_gen_prompt(request) -> str:
|
|||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
async def check_length(request, prompt):
|
async def check_length(
|
||||||
|
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||||
|
prompt: Optional[str] = None,
|
||||||
|
prompt_ids: Optional[List[int]] = None
|
||||||
|
) -> Tuple[List[int], Optional[JSONResponse]]:
|
||||||
|
assert (not (prompt is None and prompt_ids is None)
|
||||||
|
and not (prompt is not None and prompt_ids is not None)
|
||||||
|
), "Either prompt or prompt_ids should be provided."
|
||||||
|
if prompt_ids is not None:
|
||||||
|
input_ids = prompt_ids
|
||||||
|
else:
|
||||||
input_ids = tokenizer(prompt).input_ids
|
input_ids = tokenizer(prompt).input_ids
|
||||||
token_num = len(input_ids)
|
token_num = len(input_ids)
|
||||||
|
|
||||||
if token_num + request.max_tokens > max_model_len:
|
if token_num + request.max_tokens > max_model_len:
|
||||||
return create_error_response(
|
return input_ids, create_error_response(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
f"This model's maximum context length is {max_model_len} tokens. "
|
f"This model's maximum context length is {max_model_len} tokens. "
|
||||||
f"However, you requested {request.max_tokens + token_num} tokens "
|
f"However, you requested {request.max_tokens + token_num} tokens "
|
||||||
@ -129,7 +140,7 @@ async def check_length(request, prompt):
|
|||||||
f"Please reduce the length of the messages or completion.",
|
f"Please reduce the length of the messages or completion.",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return None
|
return input_ids, None
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/models")
|
@app.get("/v1/models")
|
||||||
@ -168,7 +179,8 @@ def create_logprobs(token_ids: List[int],
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
@app.post("/v1/chat/completions")
|
||||||
async def create_chat_completion(raw_request: Request):
|
async def create_chat_completion(request: ChatCompletionRequest,
|
||||||
|
raw_request: Request):
|
||||||
"""Completion API similar to OpenAI's API.
|
"""Completion API similar to OpenAI's API.
|
||||||
|
|
||||||
See https://platform.openai.com/docs/api-reference/chat/create
|
See https://platform.openai.com/docs/api-reference/chat/create
|
||||||
@ -178,9 +190,11 @@ async def create_chat_completion(raw_request: Request):
|
|||||||
- function_call (Users should implement this by themselves)
|
- function_call (Users should implement this by themselves)
|
||||||
- logit_bias (to be supported by vLLM engine)
|
- logit_bias (to be supported by vLLM engine)
|
||||||
"""
|
"""
|
||||||
request = ChatCompletionRequest(**await raw_request.json())
|
|
||||||
logger.info(f"Received chat completion request: {request}")
|
logger.info(f"Received chat completion request: {request}")
|
||||||
|
|
||||||
|
if not engine.is_running:
|
||||||
|
engine.start_background_loop()
|
||||||
|
|
||||||
error_check_ret = await check_model(request)
|
error_check_ret = await check_model(request)
|
||||||
if error_check_ret is not None:
|
if error_check_ret is not None:
|
||||||
return error_check_ret
|
return error_check_ret
|
||||||
@ -191,7 +205,7 @@ async def create_chat_completion(raw_request: Request):
|
|||||||
"logit_bias is not currently supported")
|
"logit_bias is not currently supported")
|
||||||
|
|
||||||
prompt = await get_gen_prompt(request)
|
prompt = await get_gen_prompt(request)
|
||||||
error_check_ret = await check_length(request, prompt)
|
token_ids, error_check_ret = await check_length(request, prompt=prompt)
|
||||||
if error_check_ret is not None:
|
if error_check_ret is not None:
|
||||||
return error_check_ret
|
return error_check_ret
|
||||||
|
|
||||||
@ -215,7 +229,8 @@ async def create_chat_completion(raw_request: Request):
|
|||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||||
|
|
||||||
result_generator = engine.generate(prompt, sampling_params, request_id)
|
result_generator = engine.generate(prompt, sampling_params, request_id,
|
||||||
|
token_ids)
|
||||||
|
|
||||||
async def abort_request() -> None:
|
async def abort_request() -> None:
|
||||||
await engine.abort(request_id)
|
await engine.abort(request_id)
|
||||||
@ -337,7 +352,7 @@ async def create_chat_completion(raw_request: Request):
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/v1/completions")
|
@app.post("/v1/completions")
|
||||||
async def create_completion(raw_request: Request):
|
async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||||
"""Completion API similar to OpenAI's API.
|
"""Completion API similar to OpenAI's API.
|
||||||
|
|
||||||
See https://platform.openai.com/docs/api-reference/completions/create
|
See https://platform.openai.com/docs/api-reference/completions/create
|
||||||
@ -350,9 +365,11 @@ async def create_completion(raw_request: Request):
|
|||||||
suffix)
|
suffix)
|
||||||
- logit_bias (to be supported by vLLM engine)
|
- logit_bias (to be supported by vLLM engine)
|
||||||
"""
|
"""
|
||||||
request = CompletionRequest(**await raw_request.json())
|
|
||||||
logger.info(f"Received completion request: {request}")
|
logger.info(f"Received completion request: {request}")
|
||||||
|
|
||||||
|
if not engine.is_running:
|
||||||
|
engine.start_background_loop()
|
||||||
|
|
||||||
error_check_ret = await check_model(request)
|
error_check_ret = await check_model(request)
|
||||||
if error_check_ret is not None:
|
if error_check_ret is not None:
|
||||||
return error_check_ret
|
return error_check_ret
|
||||||
@ -375,17 +392,34 @@ async def create_completion(raw_request: Request):
|
|||||||
|
|
||||||
model_name = request.model
|
model_name = request.model
|
||||||
request_id = f"cmpl-{random_uuid()}"
|
request_id = f"cmpl-{random_uuid()}"
|
||||||
|
|
||||||
|
use_token_ids = False
|
||||||
if isinstance(request.prompt, list):
|
if isinstance(request.prompt, list):
|
||||||
if len(request.prompt) == 0:
|
if len(request.prompt) == 0:
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||||
"please provide at least one prompt")
|
"please provide at least one prompt")
|
||||||
|
first_element = request.prompt[0]
|
||||||
|
if isinstance(first_element, int):
|
||||||
|
use_token_ids = True
|
||||||
|
prompt = request.prompt
|
||||||
|
elif isinstance(first_element, (str, list)):
|
||||||
|
# TODO: handles multiple prompt case in list[list[int]]
|
||||||
if len(request.prompt) > 1:
|
if len(request.prompt) > 1:
|
||||||
return create_error_response(
|
return create_error_response(
|
||||||
HTTPStatus.BAD_REQUEST,
|
HTTPStatus.BAD_REQUEST,
|
||||||
"multiple prompts in a batch is not currently supported")
|
"multiple prompts in a batch is not currently supported")
|
||||||
|
use_token_ids = not isinstance(first_element, str)
|
||||||
prompt = request.prompt[0]
|
prompt = request.prompt[0]
|
||||||
else:
|
else:
|
||||||
prompt = request.prompt
|
prompt = request.prompt
|
||||||
|
|
||||||
|
if use_token_ids:
|
||||||
|
_, error_check_ret = await check_length(request, prompt_ids=prompt)
|
||||||
|
else:
|
||||||
|
token_ids, error_check_ret = await check_length(request, prompt=prompt)
|
||||||
|
if error_check_ret is not None:
|
||||||
|
return error_check_ret
|
||||||
|
|
||||||
created_time = int(time.time())
|
created_time = int(time.time())
|
||||||
try:
|
try:
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
@ -405,7 +439,14 @@ async def create_completion(raw_request: Request):
|
|||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||||
|
|
||||||
result_generator = engine.generate(prompt, sampling_params, request_id)
|
if use_token_ids:
|
||||||
|
result_generator = engine.generate(None,
|
||||||
|
sampling_params,
|
||||||
|
request_id,
|
||||||
|
prompt_token_ids=prompt)
|
||||||
|
else:
|
||||||
|
result_generator = engine.generate(prompt, sampling_params, request_id,
|
||||||
|
token_ids)
|
||||||
|
|
||||||
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||||
# results. In addition, we do not stream the results when use beam search.
|
# results. In addition, we do not stream the results when use beam search.
|
||||||
@ -586,7 +627,8 @@ if __name__ == "__main__":
|
|||||||
served_model = args.model
|
served_model = args.model
|
||||||
|
|
||||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
engine = AsyncLLMEngine.from_engine_args(engine_args,
|
||||||
|
start_engine_loop=False)
|
||||||
engine_model_config = asyncio.run(engine.get_model_config())
|
engine_model_config = asyncio.run(engine.get_model_config())
|
||||||
max_model_len = engine_model_config.get_max_model_len()
|
max_model_len = engine_model_config.get_max_model_len()
|
||||||
|
|
||||||
|
@ -74,7 +74,8 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
|
|
||||||
class CompletionRequest(BaseModel):
|
class CompletionRequest(BaseModel):
|
||||||
model: str
|
model: str
|
||||||
prompt: Union[str, List[str]]
|
# a string, array of strings, array of tokens, or array of token arrays
|
||||||
|
prompt: Union[List[int], List[List[int]], str, List[str]]
|
||||||
suffix: Optional[str] = None
|
suffix: Optional[str] = None
|
||||||
max_tokens: Optional[int] = 16
|
max_tokens: Optional[int] = 16
|
||||||
temperature: Optional[float] = 1.0
|
temperature: Optional[float] = 1.0
|
||||||
|
@ -4,23 +4,6 @@ import torch.nn as nn
|
|||||||
|
|
||||||
from vllm import activation_ops
|
from vllm import activation_ops
|
||||||
|
|
||||||
_ACTIVATION_REGISTRY = {
|
|
||||||
"gelu": nn.GELU(),
|
|
||||||
# NOTE: The following GELU functions may introduce small rounding errors.
|
|
||||||
"gelu_new": nn.GELU(approximate="tanh"),
|
|
||||||
"gelu_fast": nn.GELU(approximate="tanh"),
|
|
||||||
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
|
|
||||||
"relu": nn.ReLU(),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_act_fn(act_fn: str) -> nn.Module:
|
|
||||||
"""Get an activation function by name."""
|
|
||||||
act_fn = act_fn.lower()
|
|
||||||
if act_fn in _ACTIVATION_REGISTRY:
|
|
||||||
return _ACTIVATION_REGISTRY[act_fn]
|
|
||||||
raise ValueError(f"Activation function {act_fn!r} is not supported.")
|
|
||||||
|
|
||||||
|
|
||||||
class SiluAndMul(nn.Module):
|
class SiluAndMul(nn.Module):
|
||||||
"""An activation function for SwiGLU.
|
"""An activation function for SwiGLU.
|
||||||
@ -38,3 +21,40 @@ class SiluAndMul(nn.Module):
|
|||||||
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
|
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
|
||||||
activation_ops.silu_and_mul(out, x)
|
activation_ops.silu_and_mul(out, x)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class NewGELU(nn.Module):
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
num_tokens = x.shape[0]
|
||||||
|
d = x.shape[1]
|
||||||
|
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
|
||||||
|
activation_ops.gelu_new(out, x)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class FastGELU(nn.Module):
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
num_tokens = x.shape[0]
|
||||||
|
d = x.shape[1]
|
||||||
|
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
|
||||||
|
activation_ops.gelu_fast(out, x)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
_ACTIVATION_REGISTRY = {
|
||||||
|
"gelu": nn.GELU(),
|
||||||
|
"gelu_fast": FastGELU(),
|
||||||
|
"gelu_new": NewGELU(),
|
||||||
|
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
|
||||||
|
"relu": nn.ReLU(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_act_fn(act_fn: str) -> nn.Module:
|
||||||
|
"""Get an activation function by name."""
|
||||||
|
act_fn = act_fn.lower()
|
||||||
|
if act_fn in _ACTIVATION_REGISTRY:
|
||||||
|
return _ACTIVATION_REGISTRY[act_fn]
|
||||||
|
raise ValueError(f"Activation function {act_fn!r} is not supported.")
|
||||||
|
@ -61,7 +61,6 @@ class PagedAttention(nn.Module):
|
|||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
self.attn_op = xops.fmha.cutlass.FwOp()
|
|
||||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
assert self.num_heads % self.num_kv_heads == 0
|
||||||
@ -115,7 +114,6 @@ class PagedAttention(nn.Module):
|
|||||||
attn_bias=input_metadata.attn_bias[0],
|
attn_bias=input_metadata.attn_bias[0],
|
||||||
p=0.0,
|
p=0.0,
|
||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
op=self.attn_op,
|
|
||||||
)
|
)
|
||||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||||
output.copy_(out.squeeze(0))
|
output.copy_(out.squeeze(0))
|
||||||
@ -244,7 +242,7 @@ class PagedAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class PagedAttentionWithRoPE(PagedAttention):
|
class PagedAttentionWithRoPE(PagedAttention):
|
||||||
"""PagedAttention with GPT-NeoX style rotary embedding."""
|
"""PagedAttention with rotary embedding."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -255,8 +253,10 @@ class PagedAttentionWithRoPE(PagedAttention):
|
|||||||
max_position: int = 8192,
|
max_position: int = 8192,
|
||||||
base: int = 10000,
|
base: int = 10000,
|
||||||
num_kv_heads: Optional[int] = None,
|
num_kv_heads: Optional[int] = None,
|
||||||
|
is_neox_style: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(num_heads, head_size, scale, num_kv_heads)
|
super().__init__(num_heads, head_size, scale, num_kv_heads)
|
||||||
|
self.is_neox_style = is_neox_style
|
||||||
|
|
||||||
# Create the cos and sin cache.
|
# Create the cos and sin cache.
|
||||||
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
|
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
|
||||||
@ -305,12 +305,13 @@ class PagedAttentionWithRoPE(PagedAttention):
|
|||||||
|
|
||||||
# Apply rotary embedding to the query and key before passing them
|
# Apply rotary embedding to the query and key before passing them
|
||||||
# to the attention op.
|
# to the attention op.
|
||||||
pos_encoding_ops.rotary_embedding_neox(
|
pos_encoding_ops.rotary_embedding(
|
||||||
positions,
|
positions,
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
self.head_size,
|
self.head_size,
|
||||||
self.cos_sin_cache,
|
self.cos_sin_cache,
|
||||||
|
self.is_neox_style,
|
||||||
)
|
)
|
||||||
return super().forward(
|
return super().forward(
|
||||||
query,
|
query,
|
||||||
@ -357,11 +358,12 @@ class PagedAttentionWithALiBi(PagedAttention):
|
|||||||
# be sliced from a tensor whose length is a multiple of 8.
|
# be sliced from a tensor whose length is a multiple of 8.
|
||||||
padded_len = (prompt_len + 7) // 8 * 8
|
padded_len = (prompt_len + 7) // 8 * 8
|
||||||
bias = torch.empty(
|
bias = torch.empty(
|
||||||
|
1, # batch_size
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
padded_len,
|
prompt_len,
|
||||||
padded_len,
|
padded_len,
|
||||||
device=self.alibi_slopes.device,
|
device=self.alibi_slopes.device,
|
||||||
)[:, :prompt_len, :prompt_len].copy_(bias)
|
)[:, :, :, :prompt_len].copy_(bias)
|
||||||
bias.mul_(self.alibi_slopes[:, None, None])
|
bias.mul_(self.alibi_slopes[:, None, None])
|
||||||
attn_bias = LowerTriangularMaskWithTensorBias(bias)
|
attn_bias = LowerTriangularMaskWithTensorBias(bias)
|
||||||
input_metadata.attn_bias.append(attn_bias)
|
input_metadata.attn_bias.append(attn_bias)
|
||||||
@ -403,7 +405,6 @@ class PagedAttentionWithALiBi(PagedAttention):
|
|||||||
attn_bias=input_metadata.attn_bias[i],
|
attn_bias=input_metadata.attn_bias[i],
|
||||||
p=0.0,
|
p=0.0,
|
||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
op=self.attn_op,
|
|
||||||
)
|
)
|
||||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||||
output[start:end].copy_(out.squeeze(0))
|
output[start:end].copy_(out.squeeze(0))
|
||||||
|
@ -9,7 +9,7 @@ from vllm.model_executor.input_metadata import InputMetadata
|
|||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
gather_from_tensor_model_parallel_region)
|
gather_from_tensor_model_parallel_region)
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SamplerOutput, SequenceOutputs
|
||||||
|
|
||||||
_SAMPLING_EPS = 1e-5
|
_SAMPLING_EPS = 1e-5
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ class Sampler(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
embedding_bias: Optional[torch.Tensor] = None,
|
embedding_bias: Optional[torch.Tensor] = None,
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
# Get the hidden states that we use for sampling.
|
# Get the hidden states that we use for sampling.
|
||||||
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
|
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
|
||||||
|
|
||||||
@ -71,20 +71,20 @@ class Sampler(nn.Module):
|
|||||||
# Use in-place division to avoid creating a new tensor.
|
# Use in-place division to avoid creating a new tensor.
|
||||||
logits.div_(t.unsqueeze(dim=1))
|
logits.div_(t.unsqueeze(dim=1))
|
||||||
|
|
||||||
|
# Apply top-p and top-k truncation.
|
||||||
|
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
|
||||||
|
assert len(top_ps) == len(top_ks) == logits.shape[0]
|
||||||
|
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
|
||||||
|
do_top_k = any(k != self.vocab_size for k in top_ks)
|
||||||
|
if do_top_p or do_top_k:
|
||||||
|
logits = _apply_top_p_top_k(logits, top_ps, top_ks)
|
||||||
|
|
||||||
# We use float32 for probabilities and log probabilities.
|
# We use float32 for probabilities and log probabilities.
|
||||||
# Compute the probabilities.
|
# Compute the probabilities.
|
||||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||||
# Compute the log probabilities (before applying top-p and top-k).
|
# Compute the log probabilities (before applying top-p and top-k).
|
||||||
logprobs = torch.log(probs)
|
logprobs = torch.log(probs)
|
||||||
|
|
||||||
# Apply top-p and top-k truncation.
|
|
||||||
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
|
|
||||||
assert len(top_ps) == len(top_ks) == probs.shape[0]
|
|
||||||
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
|
|
||||||
do_top_k = any(k != self.vocab_size for k in top_ks)
|
|
||||||
if do_top_p or do_top_k:
|
|
||||||
probs = _apply_top_p_top_k(probs, top_ps, top_ks)
|
|
||||||
|
|
||||||
# Sample the next tokens.
|
# Sample the next tokens.
|
||||||
return _sample(probs, logprobs, input_metadata)
|
return _sample(probs, logprobs, input_metadata)
|
||||||
|
|
||||||
@ -100,7 +100,8 @@ def _prune_hidden_states(
|
|||||||
start_idx += prompt_len
|
start_idx += prompt_len
|
||||||
last_token_indicies.extend(
|
last_token_indicies.extend(
|
||||||
range(start_idx, start_idx + input_metadata.num_generation_tokens))
|
range(start_idx, start_idx + input_metadata.num_generation_tokens))
|
||||||
return hidden_states[last_token_indicies]
|
return hidden_states.index_select(
|
||||||
|
0, torch.tensor(last_token_indicies, device=hidden_states.device))
|
||||||
|
|
||||||
|
|
||||||
def _get_penalties(
|
def _get_penalties(
|
||||||
@ -157,7 +158,7 @@ def _apply_penalties(
|
|||||||
continue
|
continue
|
||||||
p = presence_penalties[i]
|
p = presence_penalties[i]
|
||||||
f = frequency_penalties[i]
|
f = frequency_penalties[i]
|
||||||
if p < _SAMPLING_EPS and f < _SAMPLING_EPS:
|
if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
|
||||||
continue
|
continue
|
||||||
indices.append(i)
|
indices.append(i)
|
||||||
|
|
||||||
@ -235,31 +236,32 @@ def _get_top_p_top_k(
|
|||||||
|
|
||||||
|
|
||||||
def _apply_top_p_top_k(
|
def _apply_top_p_top_k(
|
||||||
probs: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
top_ps: List[float],
|
top_ps: List[float],
|
||||||
top_ks: List[int],
|
top_ks: List[int],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
|
p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
|
||||||
k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
|
k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
|
||||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
|
||||||
|
|
||||||
# Apply top-p.
|
# Apply top-p.
|
||||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
probs_sort = logits_sort.softmax(dim=-1)
|
||||||
|
probs_sum = probs_sort.cumsum(dim=-1)
|
||||||
top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
|
top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
|
||||||
probs_sort[top_p_mask] = 0.0
|
logits_sort[top_p_mask] = -float("inf")
|
||||||
|
|
||||||
# Apply top-k.
|
# Apply top-k.
|
||||||
# Create a mask for the top-k elements.
|
# Create a mask for the top-k elements.
|
||||||
top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
|
top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
|
||||||
top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
|
top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
|
||||||
top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
|
top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
|
||||||
probs_sort[top_k_mask] = 0.0
|
logits_sort[top_k_mask] = -float("inf")
|
||||||
|
|
||||||
# Re-sort the probabilities.
|
# Re-sort the probabilities.
|
||||||
probs = torch.gather(probs_sort,
|
logits = torch.gather(logits_sort,
|
||||||
dim=-1,
|
dim=-1,
|
||||||
index=torch.argsort(probs_idx, dim=-1))
|
index=torch.argsort(logits_idx, dim=-1))
|
||||||
return probs
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def _get_topk_logprobs(
|
def _get_topk_logprobs(
|
||||||
@ -290,7 +292,13 @@ def _sample_from_prompt(
|
|||||||
if sampling_params.use_beam_search:
|
if sampling_params.use_beam_search:
|
||||||
# Beam search.
|
# Beam search.
|
||||||
beam_width = sampling_params.best_of
|
beam_width = sampling_params.best_of
|
||||||
_, next_token_ids = torch.topk(prob, beam_width)
|
# Sample 2 * beam_width candidates to make sure that with high
|
||||||
|
# probability we can get `beam_width` candidates in addition to
|
||||||
|
# the finished sequences for the next iteration. See
|
||||||
|
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
|
||||||
|
# for details. See also HF reference:
|
||||||
|
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
|
||||||
|
_, next_token_ids = torch.topk(prob, 2 * beam_width)
|
||||||
next_token_ids = next_token_ids.tolist()
|
next_token_ids = next_token_ids.tolist()
|
||||||
elif sampling_params.temperature < _SAMPLING_EPS:
|
elif sampling_params.temperature < _SAMPLING_EPS:
|
||||||
# Greedy sampling.
|
# Greedy sampling.
|
||||||
@ -328,29 +336,11 @@ def _sample_from_generation_tokens(
|
|||||||
|
|
||||||
vocab_size = logprobs.size(-1)
|
vocab_size = logprobs.size(-1)
|
||||||
beam_width = len(seq_ids)
|
beam_width = len(seq_ids)
|
||||||
_, topk_ids = torch.topk(logprobs.flatten(), beam_width)
|
_, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
|
||||||
topk_ids = topk_ids.tolist()
|
topk_ids = topk_ids.tolist()
|
||||||
seq_idx = [i // vocab_size for i in topk_ids]
|
seq_idx = [i // vocab_size for i in topk_ids]
|
||||||
beam_seq_ids = [seq_ids[i] for i in seq_idx]
|
parent_seq_ids = [seq_ids[i] for i in seq_idx]
|
||||||
token_ids = [i % vocab_size for i in topk_ids]
|
next_token_ids = [i % vocab_size for i in topk_ids]
|
||||||
|
|
||||||
beam_outputs: Dict[int, Tuple[int, int]] = {}
|
|
||||||
outstanding_beams: List[Tuple[int, int]] = []
|
|
||||||
# If a beam survives, continue with it.
|
|
||||||
for seq_id, token_id in zip(beam_seq_ids, token_ids):
|
|
||||||
if seq_id not in beam_outputs:
|
|
||||||
beam_outputs[seq_id] = (seq_id, token_id)
|
|
||||||
else:
|
|
||||||
outstanding_beams.append((seq_id, token_id))
|
|
||||||
|
|
||||||
# If a beam is discarded, fork another beam.
|
|
||||||
for seq_id in seq_ids:
|
|
||||||
if seq_id not in beam_outputs:
|
|
||||||
beam_outputs[seq_id] = outstanding_beams.pop()
|
|
||||||
assert not outstanding_beams
|
|
||||||
|
|
||||||
parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
|
|
||||||
next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
|
|
||||||
elif sampling_params.temperature < _SAMPLING_EPS:
|
elif sampling_params.temperature < _SAMPLING_EPS:
|
||||||
# Greedy sampling.
|
# Greedy sampling.
|
||||||
assert len(seq_ids) == 1
|
assert len(seq_ids) == 1
|
||||||
@ -372,16 +362,18 @@ def _sample(
|
|||||||
probs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
logprobs: torch.Tensor,
|
logprobs: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
seq_outputs: Dict[int, SequenceOutputs] = {}
|
seq_outputs: SamplerOutput = []
|
||||||
|
|
||||||
# TODO(woosuk): Optimize.
|
# TODO(woosuk): Optimize.
|
||||||
idx = 0
|
idx = 0
|
||||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||||
|
seq_group_outputs: List[SequenceOutputs] = []
|
||||||
seq_ids, sampling_params = seq_group
|
seq_ids, sampling_params = seq_group
|
||||||
if i < input_metadata.num_prompts:
|
if i < input_metadata.num_prompts:
|
||||||
# Generate the next tokens for a prompt input.
|
# Generate the next tokens for a prompt input.
|
||||||
assert len(seq_ids) == sampling_params.best_of
|
assert len(seq_ids) == 1, "Prompt input should have only one seq."
|
||||||
|
parent_seq_id = seq_ids[0]
|
||||||
prob = probs[idx]
|
prob = probs[idx]
|
||||||
logprob = logprobs[idx]
|
logprob = logprobs[idx]
|
||||||
idx += 1
|
idx += 1
|
||||||
@ -393,17 +385,18 @@ def _sample(
|
|||||||
sampling_params.logprobs)
|
sampling_params.logprobs)
|
||||||
|
|
||||||
# Build the output.
|
# Build the output.
|
||||||
for seq_id, next_token_id in zip(seq_ids, next_token_ids):
|
for next_token_id in next_token_ids:
|
||||||
output_logprobs = next_logprobs.copy()
|
output_logprobs = next_logprobs.copy()
|
||||||
output_logprobs[next_token_id] = logprob[next_token_id].item()
|
output_logprobs[next_token_id] = logprob[next_token_id].item()
|
||||||
seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
|
seq_group_outputs.append(
|
||||||
next_token_id,
|
SequenceOutputs(parent_seq_id, next_token_id,
|
||||||
output_logprobs)
|
output_logprobs))
|
||||||
else:
|
else:
|
||||||
# Generate the next tokens for generation tokens.
|
# Generate the next tokens for generation tokens.
|
||||||
prob = probs[idx:idx + len(seq_ids)]
|
num_parent_seqs = len(seq_ids)
|
||||||
logprob = logprobs[idx:idx + len(seq_ids)]
|
prob = probs[idx:idx + num_parent_seqs]
|
||||||
idx += len(seq_ids)
|
logprob = logprobs[idx:idx + num_parent_seqs]
|
||||||
|
idx += num_parent_seqs
|
||||||
|
|
||||||
# Sample the next tokens.
|
# Sample the next tokens.
|
||||||
seq_logprobs = [
|
seq_logprobs = [
|
||||||
@ -420,17 +413,15 @@ def _sample(
|
|||||||
logprob[j], sampling_params.logprobs)
|
logprob[j], sampling_params.logprobs)
|
||||||
|
|
||||||
# Build the output.
|
# Build the output.
|
||||||
for seq_id, parent_seq_id, next_token_id in zip(
|
for parent_seq_id, next_token_id in zip(parent_seq_ids,
|
||||||
seq_ids, parent_seq_ids, next_token_ids):
|
next_token_ids):
|
||||||
j = seq_ids.index(parent_seq_id)
|
j = seq_ids.index(parent_seq_id)
|
||||||
output_logprobs = next_logprobs[parent_seq_id].copy()
|
output_logprobs = next_logprobs[parent_seq_id].copy()
|
||||||
output_logprobs[next_token_id] = logprob[j,
|
output_logprobs[next_token_id] = logprob[j,
|
||||||
next_token_id].item()
|
next_token_id].item()
|
||||||
seq_outputs[seq_id] = SequenceOutputs(
|
seq_group_outputs.append(
|
||||||
seq_id,
|
SequenceOutputs(parent_seq_id, next_token_id,
|
||||||
parent_seq_id,
|
output_logprobs))
|
||||||
next_token_id,
|
seq_outputs.append(seq_group_outputs)
|
||||||
output_logprobs,
|
|
||||||
)
|
|
||||||
|
|
||||||
return seq_outputs
|
return seq_outputs
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""Utilities for selecting and loading models."""
|
"""Utilities for selecting and loading models."""
|
||||||
|
import contextlib
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -11,6 +12,7 @@ from vllm.model_executor.weight_utils import initialize_dummy_weights
|
|||||||
|
|
||||||
# TODO(woosuk): Lazy-load the model classes.
|
# TODO(woosuk): Lazy-load the model classes.
|
||||||
_MODEL_REGISTRY = {
|
_MODEL_REGISTRY = {
|
||||||
|
"AquilaModel": AquilaForCausalLM,
|
||||||
"BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b
|
"BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b
|
||||||
"BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b
|
"BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b
|
||||||
"BloomForCausalLM": BloomForCausalLM,
|
"BloomForCausalLM": BloomForCausalLM,
|
||||||
@ -19,14 +21,25 @@ _MODEL_REGISTRY = {
|
|||||||
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
|
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
|
||||||
"GPTJForCausalLM": GPTJForCausalLM,
|
"GPTJForCausalLM": GPTJForCausalLM,
|
||||||
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
|
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
|
||||||
|
"InternLMForCausalLM": InternLMForCausalLM,
|
||||||
"LlamaForCausalLM": LlamaForCausalLM,
|
"LlamaForCausalLM": LlamaForCausalLM,
|
||||||
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
|
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
|
||||||
"MPTForCausalLM": MPTForCausalLM,
|
"MPTForCausalLM": MPTForCausalLM,
|
||||||
"OPTForCausalLM": OPTForCausalLM,
|
"OPTForCausalLM": OPTForCausalLM,
|
||||||
|
"QWenLMHeadModel": QWenLMHeadModel,
|
||||||
"RWForCausalLM": FalconForCausalLM,
|
"RWForCausalLM": FalconForCausalLM,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _set_default_torch_dtype(dtype: torch.dtype):
|
||||||
|
"""Sets the default torch dtype to the given dtype."""
|
||||||
|
old_dtype = torch.get_default_dtype()
|
||||||
|
torch.set_default_dtype(dtype)
|
||||||
|
yield
|
||||||
|
torch.set_default_dtype(old_dtype)
|
||||||
|
|
||||||
|
|
||||||
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||||
architectures = getattr(config, "architectures", [])
|
architectures = getattr(config, "architectures", [])
|
||||||
for arch in architectures:
|
for arch in architectures:
|
||||||
@ -39,12 +52,11 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
|||||||
|
|
||||||
def get_model(model_config: ModelConfig) -> nn.Module:
|
def get_model(model_config: ModelConfig) -> nn.Module:
|
||||||
model_class = _get_model_architecture(model_config.hf_config)
|
model_class = _get_model_architecture(model_config.hf_config)
|
||||||
torch.set_default_dtype(model_config.dtype)
|
with _set_default_torch_dtype(model_config.dtype):
|
||||||
|
|
||||||
# Create a model instance.
|
# Create a model instance.
|
||||||
# The weights will be initialized as empty tensors.
|
# The weights will be initialized as empty tensors.
|
||||||
model = model_class(model_config.hf_config)
|
model = model_class(model_config.hf_config)
|
||||||
if model_config.use_dummy_weights:
|
if model_config.load_format == "dummy":
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||||
# random values to the weights.
|
# random values to the weights.
|
||||||
@ -52,6 +64,6 @@ def get_model(model_config: ModelConfig) -> nn.Module:
|
|||||||
else:
|
else:
|
||||||
# Load the weights from the cached or downloaded files.
|
# Load the weights from the cached or downloaded files.
|
||||||
model.load_weights(model_config.model, model_config.download_dir,
|
model.load_weights(model_config.model, model_config.download_dir,
|
||||||
model_config.use_np_weights)
|
model_config.load_format)
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from vllm.model_executor.models.aquila import AquilaForCausalLM
|
||||||
from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM,
|
from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM,
|
||||||
BaichuanForCausalLM)
|
BaichuanForCausalLM)
|
||||||
from vllm.model_executor.models.bloom import BloomForCausalLM
|
from vllm.model_executor.models.bloom import BloomForCausalLM
|
||||||
@ -6,11 +7,14 @@ from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
|
|||||||
from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
|
from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
|
||||||
from vllm.model_executor.models.gpt_j import GPTJForCausalLM
|
from vllm.model_executor.models.gpt_j import GPTJForCausalLM
|
||||||
from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
|
from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
|
||||||
|
from vllm.model_executor.models.internlm import InternLMForCausalLM
|
||||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||||
from vllm.model_executor.models.mpt import MPTForCausalLM
|
from vllm.model_executor.models.mpt import MPTForCausalLM
|
||||||
from vllm.model_executor.models.opt import OPTForCausalLM
|
from vllm.model_executor.models.opt import OPTForCausalLM
|
||||||
|
from vllm.model_executor.models.qwen import QWenLMHeadModel
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"AquilaForCausalLM",
|
||||||
"BaiChuanForCausalLM",
|
"BaiChuanForCausalLM",
|
||||||
"BaichuanForCausalLM",
|
"BaichuanForCausalLM",
|
||||||
"BloomForCausalLM",
|
"BloomForCausalLM",
|
||||||
@ -19,7 +23,9 @@ __all__ = [
|
|||||||
"GPTBigCodeForCausalLM",
|
"GPTBigCodeForCausalLM",
|
||||||
"GPTJForCausalLM",
|
"GPTJForCausalLM",
|
||||||
"GPTNeoXForCausalLM",
|
"GPTNeoXForCausalLM",
|
||||||
|
"InternLMForCausalLM",
|
||||||
"LlamaForCausalLM",
|
"LlamaForCausalLM",
|
||||||
"MPTForCausalLM",
|
"MPTForCausalLM",
|
||||||
"OPTForCausalLM",
|
"OPTForCausalLM",
|
||||||
|
"QWenLMHeadModel",
|
||||||
]
|
]
|
||||||
|
357
vllm/model_executor/models/aquila.py
Normal file
357
vllm/model_executor/models/aquila.py
Normal file
@ -0,0 +1,357 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Inference-only LLaMA model compatible with HuggingFace weights.
|
||||||
|
|
||||||
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
|
"""
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.weight_utils import (
|
||||||
|
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
|
||||||
|
load_tensor_parallel_weights)
|
||||||
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
from vllm.transformers_utils.configs.aquila import AquilaConfig
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class AquilaMLP(nn.Module):
    """Feed-forward block of an Aquila decoder layer.

    The gate and up projections are fused into one column-parallel linear
    layer with output width ``2 * intermediate_size``. Its output is passed
    through ``SiluAndMul`` and then projected back to ``hidden_size`` by a
    row-parallel linear layer.

    Args:
        hidden_size: Width of the input and output hidden states.
        intermediate_size: Width of each of the fused gate/up projections.
        hidden_act: Activation name taken from the model config. Only
            ``"silu"`` is accepted.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        # Fail fast: reject an unsupported activation before any of the
        # (potentially large) projection weights are allocated.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        super().__init__()
        # gather_output=False keeps the fused projection's output in its
        # parallel form; down_proj is told so via input_is_parallel=True.
        self.gate_up_proj = ColumnParallelLinear(hidden_size,
                                                 2 * intermediate_size,
                                                 bias=False,
                                                 gather_output=False,
                                                 perform_initialization=False)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           input_is_parallel=True,
                                           perform_initialization=False)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply fused gate/up projection, activation, then down projection.

        The parallel linear layers return a pair; only the first element is
        used here.
        """
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class AquilaRMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-feature scale.

    The original implementation notes that this is equivalent to
    T5LayerNorm.

    Args:
        hidden_size: Size of the last (normalised) dimension.
        eps: Value added to the mean of squares before the reciprocal
            square root is taken.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` over its last dimension.

        The mean of squares is computed in float32 and the result is cast
        back to the dtype of the input.
        """
        original_dtype = hidden_states.dtype
        as_float32 = hidden_states.to(torch.float32)
        mean_of_squares = as_float32.pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_of_squares + self.variance_epsilon)
        # The scaling is applied to the input in its own dtype, exactly as
        # in the original expression.
        normalised = hidden_states * inverse_rms
        scaled = self.weight * normalised
        return scaled.to(original_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class AquilaAttention(nn.Module):
    """Self-attention block with a fused QKV projection.

    Head counts are divided by the tensor-parallel world size, so each rank
    keeps ``num_heads // tp_size`` query heads and ``num_kv_heads // tp_size``
    key/value heads.

    Args:
        hidden_size: Width of the hidden states.
        num_heads: Total number of query heads across all ranks.
        num_kv_heads: Total number of key/value heads across all ranks.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Both head counts must divide evenly across the ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = self.total_num_kv_heads // world_size

        # Per-rank widths of the Q and K/V slices of the fused projection.
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        fused_qkv_width = (
            (self.total_num_heads + 2 * self.total_num_kv_heads) *
            self.head_dim)
        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
            fused_qkv_width,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            perform_initialization=False,
        )
        # NOTE(review): num_kv_heads is not forwarded to the attention
        # layer; the only caller visible here passes num_kv_heads equal to
        # num_heads. Confirm before using a different K/V head count.
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_dim,
            self.scaling,
            rotary_dim=self.head_dim,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Project to Q/K/V, run attention, and apply the output projection."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        key_cache, value_cache = kv_cache
        attended = self.attn(positions, query, key, value, key_cache,
                             value_cache, input_metadata, cache_event)
        projected, _ = self.o_proj(attended)
        return projected
|
||||||
|
|
||||||
|
|
||||||
|
class AquilaDecoderLayer(nn.Module):
    """One Aquila transformer layer: attention then MLP.

    Each sub-block is preceded by its own ``AquilaRMSNorm`` and its output
    is added back onto the un-normalised input of that sub-block.

    Args:
        config: Model configuration supplying ``hidden_size``,
            ``num_attention_heads``, ``intermediate_size``, ``hidden_act``
            and ``rms_norm_eps``.
    """

    def __init__(self, config: AquilaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        # The K/V head count is set to the query head count here.
        self.self_attn = AquilaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_attention_heads,
        )
        self.mlp = AquilaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = AquilaRMSNorm(config.hidden_size,
                                             eps=config.rms_norm_eps)
        self.post_attention_layernorm = AquilaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run the attention and MLP sub-blocks with residual additions."""
        # Attention sub-block.
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.self_attn(
            positions=positions,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        after_attention = hidden_states + attn_output

        # Feed-forward sub-block.
        mlp_input = self.post_attention_layernorm(after_attention)
        mlp_output = self.mlp(mlp_input)
        return after_attention + mlp_output
|
||||||
|
|
||||||
|
|
||||||
|
class AquilaModel(nn.Module):
    """Token embedding, decoder-layer stack and final norm for Aquila.

    Args:
        config: Model configuration supplying ``pad_token_id``,
            ``vocab_size``, ``hidden_size``, ``num_hidden_layers`` and
            ``rms_norm_eps``.
    """

    def __init__(self, config: AquilaConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # The embedding table is sized with the unpadded vocabulary size.
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            perform_initialization=False)
        decoder_layers = [
            AquilaDecoderLayer(config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer, apply the final norm.

        ``kv_caches`` is indexed by layer position, and so is
        ``cache_events`` when it is not ``None``.
        """
        hidden_states = self.embed_tokens(input_ids)
        for layer_idx, layer in enumerate(self.layers):
            cache_event = (None if cache_events is None else
                           cache_events[layer_idx])
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                input_metadata,
                cache_event,
            )
        return self.norm(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
class AquilaForCausalLM(nn.Module):
    """Aquila model with an LM head and a sampler.

    ``forward`` runs the decoder stack and hands the LM-head weight together
    with the final hidden states to the sampler. ``load_weights`` copies a
    HuggingFace-style checkpoint into this rank's parameters.
    """

    # Name lists passed through to ``load_tensor_parallel_weights``.
    _column_parallel_weights = [
        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
    ]
    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = AquilaModel(config)
        # The LM head is allocated with the vocabulary size rounded up to a
        # multiple of 64; the sampler receives the unpadded size.
        padded_vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                            padded_vocab_size,
                                            bias=False,
                                            gather_output=False,
                                            perform_initialization=False)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the model and return the sampler's output."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        return self.sampler(self.lm_head.weight, hidden_states,
                            input_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto"):
        """Copy checkpoint tensors into this rank's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are written into
        row ranges of the fused ``qkv_proj`` parameter, and ``gate_proj``/
        ``up_proj`` into the two halves of ``gate_up_proj``. Embedding and
        LM-head tensors go through ``load_padded_tensor_parallel_vocab``;
        everything else through ``load_tensor_parallel_weights``.

        Args:
            model_name_or_path: Forwarded to ``hf_model_weights_iterator``.
            cache_dir: Forwarded to ``hf_model_weights_iterator``.
            load_format: Forwarded to ``hf_model_weights_iterator``.
        """
        world_size = get_tensor_model_parallel_world_size()
        rank = get_tensor_model_parallel_rank()
        cfg = self.config

        q_rows = cfg.hidden_size // world_size
        kv_rows = (cfg.hidden_size // cfg.num_attention_heads *
                   cfg.num_attention_heads // world_size)
        # (checkpoint key, rows per rank, row offset inside fused qkv_proj)
        qkv_layout = (
            ("q_proj", q_rows, 0),
            ("k_proj", kv_rows, q_rows),
            ("v_proj", kv_rows, q_rows + kv_rows),
        )
        gate_up_keys = ("gate_proj", "up_proj")
        params = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format):
            if "rotary_emb.inv_freq" in name:
                continue

            # First q/k/v key contained in the tensor name, if any.
            qkv_spec = next(
                (spec for spec in qkv_layout if spec[0] in name), None)
            if qkv_spec is not None:
                source_key, rows, row_offset = qkv_spec
                fused = params[name.replace(source_key, "qkv_proj")]
                # Take this rank's rows of the checkpoint tensor.
                loaded_weight = loaded_weight[rows * rank:rows * (rank + 1)]
                target = fused.data[row_offset:row_offset + rows]
                assert target.shape == loaded_weight.shape
                target.copy_(loaded_weight)
                continue

            # First gate/up key contained in the tensor name, if any.
            gate_up_spec = next(
                ((half, key) for half, key in enumerate(gate_up_keys)
                 if key in name), None)
            if gate_up_spec is not None:
                half, source_key = gate_up_spec
                fused = params[name.replace(source_key, "gate_up_proj")]
                rows = fused.shape[0] // 2
                loaded_weight = loaded_weight[rows * rank:rows * (rank + 1)]
                # gate_proj fills the first half, up_proj the second.
                target = fused.data[rows * half:rows * (half + 1)]
                assert target.shape == loaded_weight.shape
                target.copy_(loaded_weight)
                continue

            param = params[name]
            if "embed_tokens" in name or "lm_head" in name:
                load_padded_tensor_parallel_vocab(param, loaded_weight, rank)
                continue

            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights, rank)
|
@ -23,23 +23,25 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses
|
|||||||
InputMetadata to extract the original 2D shape of the input.
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
"""
|
"""
|
||||||
import math
|
import math
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from vllm.sequence import SequenceOutputs
|
|
||||||
from vllm.model_executor.input_metadata import InputMetadata
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
|
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
|
||||||
|
PagedAttentionWithALiBi)
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
from vllm.model_executor.weight_utils import (
|
||||||
load_tensor_parallel_weights)
|
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||||
|
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
|
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
@ -288,41 +290,30 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
|||||||
kv_caches: List[KVCache],
|
kv_caches: List[KVCache],
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_events: Optional[List[torch.cuda.Event]],
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||||
input_metadata, cache_events)
|
input_metadata, cache_events)
|
||||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||||
input_metadata)
|
input_metadata)
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
_column_parallel_weights = [
|
_column_parallel_weights = []
|
||||||
"embed_tokens.weight",
|
|
||||||
"lm_head.weight",
|
|
||||||
]
|
|
||||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_np_cache: bool = False):
|
load_format: str = "auto"):
|
||||||
tp_world_size = get_tensor_model_parallel_world_size()
|
tp_world_size = get_tensor_model_parallel_world_size()
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, load_format):
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if "embed_tokens" in name or "lm_head" in name:
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
# Consider padding in the vocab size.
|
|
||||||
param = state_dict[name]
|
|
||||||
padded_vocab_size = param.shape[0] * tp_world_size
|
|
||||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
|
||||||
extra_rows = torch.empty(num_extra_rows,
|
|
||||||
loaded_weight.shape[1])
|
|
||||||
extra_rows = extra_rows.to(loaded_weight)
|
|
||||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
|
||||||
|
|
||||||
if "W_pack" in name:
|
if "W_pack" in name:
|
||||||
total_num_heads = self.config.num_attention_heads
|
total_num_heads = self.config.num_attention_heads
|
||||||
@ -355,6 +346,12 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
param = state_dict[name]
|
param = state_dict[name]
|
||||||
|
|
||||||
|
if "embed_tokens" in name or "lm_head" in name:
|
||||||
|
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||||
|
tp_rank)
|
||||||
|
continue
|
||||||
|
|
||||||
load_tensor_parallel_weights(
|
load_tensor_parallel_weights(
|
||||||
param,
|
param,
|
||||||
loaded_weight,
|
loaded_weight,
|
||||||
|
@ -21,7 +21,7 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses
|
|||||||
InputMetadata to extract the original 2D shape of the input.
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
"""
|
"""
|
||||||
import math
|
import math
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
|||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
@ -264,7 +264,7 @@ class BloomForCausalLM(nn.Module):
|
|||||||
kv_caches: List[KVCache],
|
kv_caches: List[KVCache],
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_events: Optional[List[torch.cuda.Event]],
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||||
input_metadata, cache_events)
|
input_metadata, cache_events)
|
||||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||||
@ -279,11 +279,11 @@ class BloomForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_np_cache: bool = False):
|
load_format: str = "auto"):
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, load_format):
|
||||||
if name == "lm_head.weight":
|
if name == "lm_head.weight":
|
||||||
# Since hidden_states are parallelized, we need to
|
# Since hidden_states are parallelized, we need to
|
||||||
# load lm_head.weight in parallel.
|
# load lm_head.weight in parallel.
|
||||||
|
@ -19,7 +19,7 @@
|
|||||||
"""PyTorch Falcon model."""
|
"""PyTorch Falcon model."""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -31,14 +31,15 @@ from vllm.model_executor.layers.attention import (PagedAttention,
|
|||||||
PagedAttentionWithALiBi,
|
PagedAttentionWithALiBi,
|
||||||
PagedAttentionWithRoPE)
|
PagedAttentionWithRoPE)
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
|
||||||
|
hf_model_weights_iterator,
|
||||||
load_tensor_parallel_weights)
|
load_tensor_parallel_weights)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
|
||||||
reduce_from_tensor_model_parallel_region)
|
reduce_from_tensor_model_parallel_region)
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SamplerOutput
|
||||||
from vllm.transformers_utils.configs import RWConfig
|
from vllm.transformers_utils.configs import RWConfig
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
@ -397,7 +398,7 @@ class FalconForCausalLM(nn.Module):
|
|||||||
kv_caches: List[KVCache],
|
kv_caches: List[KVCache],
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_events: Optional[List[torch.cuda.Event]],
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
hidden_states = self.transformer(
|
hidden_states = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
positions,
|
positions,
|
||||||
@ -419,7 +420,7 @@ class FalconForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_np_cache: bool = False):
|
load_format: str = "auto"):
|
||||||
tp_size = (get_tensor_model_parallel_world_size())
|
tp_size = (get_tensor_model_parallel_world_size())
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
|
||||||
@ -451,8 +452,9 @@ class FalconForCausalLM(nn.Module):
|
|||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, load_format):
|
||||||
if "query_key_value" in name:
|
if "query_key_value" in name:
|
||||||
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
loaded_weight_size = loaded_weight.size()
|
loaded_weight_size = loaded_weight.size()
|
||||||
loaded_weight = loaded_weight.view(
|
loaded_weight = loaded_weight.view(
|
||||||
total_num_kv_heads, num_query_heads_per_kv_head + 2,
|
total_num_kv_heads, num_query_heads_per_kv_head + 2,
|
||||||
|
@ -21,7 +21,7 @@
|
|||||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
InputMetadata to extract the original 2D shape of the input.
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
"""
|
"""
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -31,13 +31,14 @@ from vllm.model_executor.input_metadata import InputMetadata
|
|||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.attention import PagedAttention
|
from vllm.model_executor.layers.attention import PagedAttention
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
from vllm.model_executor.weight_utils import (
|
||||||
load_tensor_parallel_weights)
|
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||||
|
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
@ -217,27 +218,27 @@ class GPT2LMHeadModel(nn.Module):
|
|||||||
kv_caches: List[KVCache],
|
kv_caches: List[KVCache],
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_events: Optional[List[torch.cuda.Event]],
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||||
input_metadata, cache_events)
|
input_metadata, cache_events)
|
||||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||||
input_metadata)
|
input_metadata)
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
|
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
|
||||||
_row_parallel_weights = ["c_proj.weight"]
|
_row_parallel_weights = ["c_proj.weight"]
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_np_cache: bool = False):
|
load_format: str = "auto"):
|
||||||
tensor_model_parallel_world_size = (
|
tensor_model_parallel_world_size = (
|
||||||
get_tensor_model_parallel_world_size())
|
get_tensor_model_parallel_world_size())
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, load_format):
|
||||||
if "lm_head.weight" in name:
|
if "lm_head.weight" in name:
|
||||||
# GPT-2 ties the weights of the embedding layer and the final
|
# GPT-2 ties the weights of the embedding layer and the final
|
||||||
# linear layer.
|
# linear layer.
|
||||||
@ -250,6 +251,8 @@ class GPT2LMHeadModel(nn.Module):
|
|||||||
if not name.startswith("transformer."):
|
if not name.startswith("transformer."):
|
||||||
name = "transformer." + name
|
name = "transformer." + name
|
||||||
|
|
||||||
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
|
|
||||||
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
|
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
|
||||||
# Because of this, we need to transpose the weights.
|
# Because of this, we need to transpose the weights.
|
||||||
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
|
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
|
||||||
@ -261,14 +264,9 @@ class GPT2LMHeadModel(nn.Module):
|
|||||||
param = state_dict[name]
|
param = state_dict[name]
|
||||||
|
|
||||||
if name == "transformer.wte.weight":
|
if name == "transformer.wte.weight":
|
||||||
# Consider padding in the vocab size.
|
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||||
padded_vocab_size = (param.shape[0] *
|
tensor_model_parallel_rank)
|
||||||
tensor_model_parallel_world_size)
|
continue
|
||||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
|
||||||
extra_rows = torch.empty(num_extra_rows,
|
|
||||||
loaded_weight.shape[1])
|
|
||||||
extra_rows = extra_rows.to(loaded_weight)
|
|
||||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
|
||||||
|
|
||||||
# For the fused QKV linear layer, manually shard the weights.
|
# For the fused QKV linear layer, manually shard the weights.
|
||||||
if "c_attn" in name:
|
if "c_attn" in name:
|
||||||
|
@ -22,7 +22,7 @@
|
|||||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
InputMetadata to extract the original 2D shape of the input.
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
"""
|
"""
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -32,13 +32,14 @@ from vllm.model_executor.input_metadata import InputMetadata
|
|||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.attention import PagedAttention
|
from vllm.model_executor.layers.attention import PagedAttention
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
from vllm.model_executor.weight_utils import (
|
||||||
load_tensor_parallel_weights)
|
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||||
|
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
@ -49,10 +50,11 @@ class GPTBigCodeAttention(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
total_num_heads = config.num_attention_heads
|
total_num_heads = config.num_attention_heads
|
||||||
tensor_model_parallel_world_size = (
|
self.tensor_model_parallel_world_size = (
|
||||||
get_tensor_model_parallel_world_size())
|
get_tensor_model_parallel_world_size())
|
||||||
assert total_num_heads % tensor_model_parallel_world_size == 0
|
assert total_num_heads % self.tensor_model_parallel_world_size == 0
|
||||||
self.num_heads = total_num_heads // tensor_model_parallel_world_size
|
self.num_heads = (total_num_heads //
|
||||||
|
self.tensor_model_parallel_world_size)
|
||||||
self.head_dim = self.hidden_size // total_num_heads
|
self.head_dim = self.hidden_size // total_num_heads
|
||||||
self.scale = self.head_dim**-0.5
|
self.scale = self.head_dim**-0.5
|
||||||
|
|
||||||
@ -101,7 +103,10 @@ class GPTBigCodeAttention(nn.Module):
|
|||||||
k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
|
k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
|
||||||
else:
|
else:
|
||||||
qkv, _ = self.c_attn(hidden_states)
|
qkv, _ = self.c_attn(hidden_states)
|
||||||
q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
|
q, k, v = qkv.split([
|
||||||
|
self.hidden_size // self.tensor_model_parallel_world_size,
|
||||||
|
self.kv_dim, self.kv_dim
|
||||||
|
],
|
||||||
dim=-1)
|
dim=-1)
|
||||||
key_cache, value_cache = kv_cache
|
key_cache, value_cache = kv_cache
|
||||||
attn_output = self.attn(q, k, v, key_cache, value_cache,
|
attn_output = self.attn(q, k, v, key_cache, value_cache,
|
||||||
@ -241,27 +246,27 @@ class GPTBigCodeForCausalLM(nn.Module):
|
|||||||
kv_caches: List[KVCache],
|
kv_caches: List[KVCache],
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_events: Optional[List[torch.cuda.Event]],
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||||
input_metadata, cache_events)
|
input_metadata, cache_events)
|
||||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||||
input_metadata)
|
input_metadata)
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
|
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
|
||||||
_row_parallel_weights = ["c_proj.weight"]
|
_row_parallel_weights = ["c_proj.weight"]
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_np_cache: bool = False):
|
load_format: str = "auto"):
|
||||||
tensor_model_parallel_world_size = (
|
tensor_model_parallel_world_size = (
|
||||||
get_tensor_model_parallel_world_size())
|
get_tensor_model_parallel_world_size())
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, load_format):
|
||||||
if "lm_head.weight" in name:
|
if "lm_head.weight" in name:
|
||||||
# GPT-2 ties the weights of the embedding layer and the final
|
# GPT-2 ties the weights of the embedding layer and the final
|
||||||
# linear layer.
|
# linear layer.
|
||||||
@ -290,6 +295,7 @@ class GPTBigCodeForCausalLM(nn.Module):
|
|||||||
head_start = tensor_model_parallel_rank * num_heads
|
head_start = tensor_model_parallel_rank * num_heads
|
||||||
head_end = (tensor_model_parallel_rank + 1) * num_heads
|
head_end = (tensor_model_parallel_rank + 1) * num_heads
|
||||||
|
|
||||||
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
wq, wk, wv = torch.split(
|
wq, wk, wv = torch.split(
|
||||||
loaded_weight, [hidden_size, total_kv_size, total_kv_size],
|
loaded_weight, [hidden_size, total_kv_size, total_kv_size],
|
||||||
dim=0)
|
dim=0)
|
||||||
@ -324,14 +330,9 @@ class GPTBigCodeForCausalLM(nn.Module):
|
|||||||
param = state_dict[name]
|
param = state_dict[name]
|
||||||
|
|
||||||
if name == "transformer.wte.weight":
|
if name == "transformer.wte.weight":
|
||||||
# Consider padding in the vocab size.
|
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||||
padded_vocab_size = param.shape[
|
tensor_model_parallel_rank)
|
||||||
0] * tensor_model_parallel_world_size
|
continue
|
||||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
|
||||||
extra_rows = torch.empty(num_extra_rows,
|
|
||||||
loaded_weight.shape[1])
|
|
||||||
extra_rows = extra_rows.to(loaded_weight)
|
|
||||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
|
||||||
|
|
||||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||||
self._column_parallel_weights,
|
self._column_parallel_weights,
|
||||||
|
@ -20,7 +20,7 @@
|
|||||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
InputMetadata to extract the original 2D shape of the input.
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
"""
|
"""
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -36,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
|||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
@ -67,8 +67,11 @@ class GPTJAttention(nn.Module):
|
|||||||
scaling = self.head_size**-0.5
|
scaling = self.head_size**-0.5
|
||||||
assert getattr(config, "rotary", True)
|
assert getattr(config, "rotary", True)
|
||||||
assert config.rotary_dim % 2 == 0
|
assert config.rotary_dim % 2 == 0
|
||||||
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
|
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||||
scaling, config.rotary_dim)
|
self.head_size,
|
||||||
|
scaling,
|
||||||
|
config.rotary_dim,
|
||||||
|
is_neox_style=False)
|
||||||
self.warmup = False
|
self.warmup = False
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -203,7 +206,7 @@ class GPTJForCausalLM(nn.Module):
|
|||||||
kv_caches: List[KVCache],
|
kv_caches: List[KVCache],
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_events: Optional[List[torch.cuda.Event]],
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||||
input_metadata, cache_events)
|
input_metadata, cache_events)
|
||||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||||
@ -219,11 +222,11 @@ class GPTJForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_np_cache: bool = False):
|
load_format: str = "auto"):
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, load_format):
|
||||||
if "attn.bias" in name or "attn.masked_bias" in name:
|
if "attn.bias" in name or "attn.masked_bias" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@
|
|||||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
InputMetadata to extract the original 2D shape of the input.
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
"""
|
"""
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -36,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
|||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
@ -215,7 +215,7 @@ class GPTNeoXForCausalLM(nn.Module):
|
|||||||
kv_caches: List[KVCache],
|
kv_caches: List[KVCache],
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_events: Optional[List[torch.cuda.Event]],
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
|
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
|
||||||
input_metadata, cache_events)
|
input_metadata, cache_events)
|
||||||
next_tokens = self.sampler(self.embed_out.weight, hidden_states,
|
next_tokens = self.sampler(self.embed_out.weight, hidden_states,
|
||||||
@ -231,11 +231,11 @@ class GPTNeoXForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_np_cache: bool = False):
|
load_format: str = "auto"):
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, load_format):
|
||||||
if ("attention.bias" in name or "attention.masked_bias" in name
|
if ("attention.bias" in name or "attention.masked_bias" in name
|
||||||
or "rotary_emb.inv_freq" in name):
|
or "rotary_emb.inv_freq" in name):
|
||||||
continue
|
continue
|
||||||
|
292
vllm/model_executor/models/internlm.py
Normal file
292
vllm/model_executor/models/internlm.py
Normal file
@ -0,0 +1,292 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
|
ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.weight_utils import (
|
||||||
|
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
|
||||||
|
load_tensor_parallel_weights)
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class InternLMMLP(nn.Module):
    """Feed-forward (MLP) block used by ``InternLMDecoderLayer``.

    The input goes through one fused column-parallel projection of width
    ``2 * intermediate_size``, then ``SiluAndMul``, then a row-parallel
    projection back to ``hidden_size``.

    Args:
        hidden_size: Feature size of the block's input and output.
        intermediate_size: Input feature size of ``down_proj``; the fused
            ``gate_up_proj`` output is twice this wide.
        hidden_act: Activation name taken from the model config. Only
            ``"silu"`` is accepted.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        # Validate first so an unsupported config fails before the (large)
        # projection layers are constructed.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = ColumnParallelLinear(hidden_size,
                                                 2 * intermediate_size,
                                                 bias=False,
                                                 gather_output=False,
                                                 perform_initialization=False)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           input_is_parallel=True,
                                           perform_initialization=False)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``gate_up_proj``, the activation and ``down_proj`` to ``x``."""
        # Both parallel layers return a 2-tuple; only the first item is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class InternLMAttention(nn.Module):
    """Self-attention block used by ``InternLMDecoderLayer``.

    Query, key and value come from one fused column-parallel projection
    (``qkv_proj``, with bias); attention itself is delegated to
    ``PagedAttentionWithRoPE`` and the result goes through the row-parallel
    ``o_proj`` (with bias).

    Args:
        hidden_size: Feature size of the block's input and output.
        num_heads: Total number of attention heads across all
            tensor-parallel ranks; must be divisible by the tensor-parallel
            world size.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.total_num_heads = num_heads

        # Heads are split evenly between tensor-parallel ranks.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads
        self.scaling = self.head_dim**-0.5

        projection_size = self.total_num_heads * self.head_dim
        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
            3 * projection_size,
            bias=True,
            gather_output=False,
            perform_initialization=False,
        )
        self.o_proj = RowParallelLinear(
            projection_size,
            hidden_size,
            bias=True,
            input_is_parallel=True,
            perform_initialization=False,
        )
        # The rotary embedding is applied over the full head dimension.
        self.attn = PagedAttentionWithRoPE(self.num_heads,
                                           self.head_dim,
                                           self.scaling,
                                           rotary_dim=self.head_dim)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projection."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # The fused output is cut into three equal pieces on the last dim.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        context = self.attn(positions, query, key, value, key_cache,
                            value_cache, input_metadata, cache_event)
        projected, _ = self.o_proj(context)
        return projected
|
||||||
|
|
||||||
|
|
||||||
|
class InternLMDecoderLayer(nn.Module):
    """One decoder layer: normalized attention and MLP, each with a residual.

    Args:
        config: Model configuration. Reads ``hidden_size``,
            ``num_attention_heads``, ``intermediate_size``, ``hidden_act``
            and ``rms_norm_eps``.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = InternLMAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
        )
        self.mlp = InternLMMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Return the layer output for ``hidden_states``."""
        # Attention sub-block: norm -> attention -> add back the input.
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.self_attn(
            positions=positions,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = hidden_states + attn_output

        # MLP sub-block: norm -> MLP -> add back the attention result.
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_output
|
||||||
|
|
||||||
|
|
||||||
|
class InternLMModel(nn.Module):
    """Token embedding, the stack of decoder layers, and the final norm.

    Args:
        config: Model configuration. Reads ``pad_token_id``, ``vocab_size``,
            ``hidden_size``, ``num_hidden_layers`` and ``rms_norm_eps``.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # The embedding table is sized to the vocabulary rounded up to the
        # next multiple of 64.
        padded_vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.embed_tokens = VocabParallelEmbedding(
            padded_vocab_size,
            config.hidden_size,
            perform_initialization=False)
        decoder_layers = [
            InternLMDecoderLayer(config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer, and apply the final norm.

        ``kv_caches`` and ``cache_events`` (when not ``None``) are indexed
        by layer position.
        """
        hidden_states = self.embed_tokens(input_ids)
        for layer_idx, layer in enumerate(self.layers):
            cache_event = (None if cache_events is None else
                           cache_events[layer_idx])
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                input_metadata,
                cache_event,
            )
        return self.norm(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
class InternLMForCausalLM(nn.Module):
    """InternLM decoder with an output head and a sampler.

    Args:
        config: Model configuration; also forwarded to ``InternLMModel``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = InternLMModel(config)
        # The head uses the vocabulary rounded up to a multiple of 64, while
        # the sampler is given the unpadded ``config.vocab_size``.
        padded_vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                            padded_vocab_size,
                                            bias=False,
                                            gather_output=False,
                                            perform_initialization=False)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the decoder and return the sampler's result."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        return self.sampler(self.lm_head.weight, hidden_states,
                            input_metadata)

    # Name suffixes handed to ``load_tensor_parallel_weights`` to tell it
    # which checkpoint tensors belong to column- / row-parallel layers.
    _column_parallel_weights = [
        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
    ]
    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto"):
        """Load checkpoint tensors into this model's parameters.

        Tensors are dispatched by name:

        * ``rotary_emb.inv_freq`` entries are skipped.
        * ``embed_tokens`` / ``lm_head`` entries go to
          ``load_padded_tensor_parallel_vocab``.
        * ``q_proj`` / ``k_proj`` / ``v_proj`` entries are copied into their
          slot of the fused ``qkv_proj`` parameter; ``gate_proj`` /
          ``up_proj`` entries into their slot of ``gate_up_proj``.
        * Everything else goes to ``load_tensor_parallel_weights``.

        Args:
            model_name_or_path: Forwarded to ``hf_model_weights_iterator``.
            cache_dir: Forwarded to ``hf_model_weights_iterator``.
            load_format: Forwarded to ``hf_model_weights_iterator``.
        """
        tp_rank = get_tensor_model_parallel_rank()
        params = self.state_dict()
        # (fused parameter name, checkpoint part names in the order their
        # slots are laid out along dim 0 of the fused parameter).
        fused_layouts = (
            ("qkv_proj", ("q_proj", "k_proj", "v_proj")),
            ("gate_up_proj", ("gate_proj", "up_proj")),
        )

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format):
            if "rotary_emb.inv_freq" in name:
                continue

            if "embed_tokens" in name or "lm_head" in name:
                load_padded_tensor_parallel_vocab(params[name], loaded_weight,
                                                  tp_rank)
                continue

            # First part name contained in ``name`` wins; attention parts
            # are checked before the gate/up parts.
            fused_match = next(
                ((fused_name, part_name, part_index, len(part_names))
                 for fused_name, part_names in fused_layouts
                 for part_index, part_name in enumerate(part_names)
                 if part_name in name),
                None,
            )

            if fused_match is None:
                load_tensor_parallel_weights(params[name], loaded_weight,
                                             name,
                                             self._column_parallel_weights,
                                             self._row_parallel_weights,
                                             tp_rank)
                continue

            fused_name, part_name, part_index, num_parts = fused_match
            param = params[name.replace(part_name, fused_name)]
            # The local fused parameter holds ``num_parts`` equal slots.
            shard_size = param.shape[0] // num_parts
            # Rows of the checkpoint tensor that belong to this rank.
            rank_shard = loaded_weight[shard_size * tp_rank:shard_size *
                                       (tp_rank + 1)]
            # Slot of the fused parameter reserved for this part.
            target = param.data[shard_size * part_index:shard_size *
                                (part_index + 1)]
            assert target.shape == rank_shard.shape
            target.copy_(rank_shard)
|
@ -25,7 +25,7 @@
|
|||||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
InputMetadata to extract the original 2D shape of the input.
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
"""
|
"""
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -36,13 +36,14 @@ from vllm.model_executor.layers.activation import SiluAndMul
|
|||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
from vllm.model_executor.weight_utils import (
|
||||||
load_tensor_parallel_weights)
|
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
|
||||||
|
hf_model_weights_iterator)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
@ -85,6 +86,7 @@ class LlamaAttention(nn.Module):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
|
rope_theta: float = 10000,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -99,6 +101,7 @@ class LlamaAttention(nn.Module):
|
|||||||
self.q_size = self.num_heads * self.head_dim
|
self.q_size = self.num_heads * self.head_dim
|
||||||
self.kv_size = self.num_kv_heads * self.head_dim
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
|
||||||
self.qkv_proj = ColumnParallelLinear(
|
self.qkv_proj = ColumnParallelLinear(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
@ -118,6 +121,7 @@ class LlamaAttention(nn.Module):
|
|||||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
|
base=self.rope_theta,
|
||||||
rotary_dim=self.head_dim,
|
rotary_dim=self.head_dim,
|
||||||
num_kv_heads=self.num_kv_heads)
|
num_kv_heads=self.num_kv_heads)
|
||||||
|
|
||||||
@ -143,10 +147,13 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
def __init__(self, config: LlamaConfig):
|
def __init__(self, config: LlamaConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
# Requires transformers > 4.32.0
|
||||||
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
self.self_attn = LlamaAttention(
|
self.self_attn = LlamaAttention(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
num_kv_heads=config.num_key_value_heads,
|
num_kv_heads=config.num_key_value_heads,
|
||||||
|
rope_theta=rope_theta,
|
||||||
)
|
)
|
||||||
self.mlp = LlamaMLP(
|
self.mlp = LlamaMLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
@ -249,7 +256,7 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
kv_caches: List[KVCache],
|
kv_caches: List[KVCache],
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_events: Optional[List[torch.cuda.Event]],
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||||
input_metadata, cache_events)
|
input_metadata, cache_events)
|
||||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||||
@ -257,15 +264,14 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
_column_parallel_weights = [
|
_column_parallel_weights = [
|
||||||
"embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
|
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
|
||||||
"gate_proj.weight", "up_proj.weight"
|
|
||||||
]
|
]
|
||||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||||
|
|
||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_np_cache: bool = False):
|
load_format: str = "auto"):
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||||
@ -282,20 +288,10 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, load_format):
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if "embed_tokens" in name or "lm_head" in name:
|
|
||||||
param = state_dict[name]
|
|
||||||
# Consider padding in the vocab size.
|
|
||||||
padded_vocab_size = (param.shape[0] * tp_size)
|
|
||||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
|
||||||
extra_rows = torch.empty(num_extra_rows,
|
|
||||||
loaded_weight.shape[1])
|
|
||||||
extra_rows = extra_rows.to(loaded_weight)
|
|
||||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
|
||||||
|
|
||||||
is_attention_weight = False
|
is_attention_weight = False
|
||||||
for weight_name, shard_size, offset in attention_weight_specs:
|
for weight_name, shard_size, offset in attention_weight_specs:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
@ -333,6 +329,12 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
param = state_dict[name]
|
param = state_dict[name]
|
||||||
|
|
||||||
|
if "embed_tokens" in name or "lm_head" in name:
|
||||||
|
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||||
|
tensor_model_parallel_rank)
|
||||||
|
continue
|
||||||
|
|
||||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||||
self._column_parallel_weights,
|
self._column_parallel_weights,
|
||||||
self._row_parallel_weights,
|
self._row_parallel_weights,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
|
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
|
||||||
import math
|
import math
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -10,13 +10,14 @@ from vllm.model_executor.input_metadata import InputMetadata
|
|||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
|
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
|
||||||
|
hf_model_weights_iterator,
|
||||||
load_tensor_parallel_weights)
|
load_tensor_parallel_weights)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SamplerOutput
|
||||||
from vllm.transformers_utils.configs.mpt import MPTConfig
|
from vllm.transformers_utils.configs.mpt import MPTConfig
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
@ -230,7 +231,7 @@ class MPTForCausalLM(nn.Module):
|
|||||||
kv_caches: List[KVCache],
|
kv_caches: List[KVCache],
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_events: Optional[List[torch.cuda.Event]],
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||||
input_metadata, cache_events)
|
input_metadata, cache_events)
|
||||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||||
@ -243,12 +244,12 @@ class MPTForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_np_cache: bool = False):
|
load_format: str = "auto"):
|
||||||
tp_world_size = get_tensor_model_parallel_world_size()
|
tp_world_size = get_tensor_model_parallel_world_size()
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, load_format):
|
||||||
if "Wqkv" in name:
|
if "Wqkv" in name:
|
||||||
# NOTE(woosuk): MPT's fused QKV has the shape of
|
# NOTE(woosuk): MPT's fused QKV has the shape of
|
||||||
# [3 * num_heads * head_size, hidden_size].
|
# [3 * num_heads * head_size, hidden_size].
|
||||||
@ -260,7 +261,7 @@ class MPTForCausalLM(nn.Module):
|
|||||||
num_heads = total_num_heads // tp_world_size
|
num_heads = total_num_heads // tp_world_size
|
||||||
head_start = tp_rank * num_heads
|
head_start = tp_rank * num_heads
|
||||||
head_end = (tp_rank + 1) * num_heads
|
head_end = (tp_rank + 1) * num_heads
|
||||||
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
if name.endswith(".weight"):
|
if name.endswith(".weight"):
|
||||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||||
head_size, hidden_size)
|
head_size, hidden_size)
|
||||||
|
@ -21,7 +21,7 @@
|
|||||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
InputMetadata to extract the original 2D shape of the input.
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
"""
|
"""
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
|||||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||||
from vllm.sequence import SequenceOutputs
|
from vllm.sequence import SamplerOutput
|
||||||
|
|
||||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
@ -282,7 +282,7 @@ class OPTForCausalLM(nn.Module):
|
|||||||
kv_caches: List[KVCache],
|
kv_caches: List[KVCache],
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_events: Optional[List[torch.cuda.Event]],
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||||
input_metadata, cache_events)
|
input_metadata, cache_events)
|
||||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||||
@ -297,12 +297,12 @@ class OPTForCausalLM(nn.Module):
|
|||||||
def load_weights(self,
|
def load_weights(self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
use_np_cache: bool = False):
|
load_format: str = "auto"):
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
for name, loaded_weight in hf_model_weights_iterator(
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
model_name_or_path, cache_dir, use_np_cache):
|
model_name_or_path, cache_dir, load_format):
|
||||||
if "lm_head.weight" in name:
|
if "lm_head.weight" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
316
vllm/model_executor/models/qwen.py
Normal file
316
vllm/model_executor/models/qwen.py
Normal file
@ -0,0 +1,316 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Adapted from
|
||||||
|
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
|
||||||
|
# Copyright (c) Alibaba Cloud.
|
||||||
|
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
|
||||||
|
"""Inference-only QWen model compatible with HuggingFace weights.
|
||||||
|
|
||||||
|
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||||
|
InputMetadata to extract the original 2D shape of the input.
|
||||||
|
"""
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from vllm.model_executor.input_metadata import InputMetadata
|
||||||
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.weight_utils import (
|
||||||
|
convert_pyslice_to_tensor,
|
||||||
|
hf_model_weights_iterator,
|
||||||
|
load_padded_tensor_parallel_vocab,
|
||||||
|
load_tensor_parallel_weights,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
|
VocabParallelEmbedding,
|
||||||
|
ColumnParallelLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
)
|
||||||
|
from vllm.sequence import SamplerOutput
|
||||||
|
from vllm.transformers_utils.configs.qwen import QWenConfig
|
||||||
|
|
||||||
|
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
class QWenMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
hidden_act: str = "silu",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.gate_up_proj = ColumnParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
2 * intermediate_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
)
|
||||||
|
self.c_proj = RowParallelLinear(
|
||||||
|
intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False,
|
||||||
|
)
|
||||||
|
if hidden_act != "silu":
|
||||||
|
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||||
|
"Only silu is supported for now.")
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
|
x = self.act_fn(gate_up)
|
||||||
|
x, _ = self.c_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class QWenAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, hidden_size: int, num_heads: int,
|
||||||
|
max_position_embeddings: int):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
|
||||||
|
)
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
||||||
|
self.num_heads = (self.total_num_heads //
|
||||||
|
tensor_model_parallel_world_size)
|
||||||
|
self.head_dim = hidden_size // self.total_num_heads
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
self.c_attn = ColumnParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
3 * hidden_size,
|
||||||
|
bias=True,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
)
|
||||||
|
self.c_proj = RowParallelLinear(
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
input_is_parallel=True,
|
||||||
|
perform_initialization=False,
|
||||||
|
)
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.attn = PagedAttentionWithRoPE(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
max_position=max_position_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.c_attn(hidden_states)
|
||||||
|
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||||
|
|
||||||
|
k_cache, v_cache = kv_cache
|
||||||
|
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||||
|
input_metadata, cache_event)
|
||||||
|
|
||||||
|
output, _ = self.c_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class QWenBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: QWenConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
|
self.attn = QWenAttention(config.n_embd, config.num_attention_heads,
|
||||||
|
config.max_position_embeddings)
|
||||||
|
|
||||||
|
self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
|
self.mlp = QWenMLP(config.n_embd, config.ffn_hidden_size // 2)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_event: Optional[torch.cuda.Event],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Self Attention
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.ln_1(hidden_states)
|
||||||
|
hidden_states = self.attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
input_metadata=input_metadata,
|
||||||
|
cache_event=cache_event,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.ln_2(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class QWenModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: QWenConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
|
self.wte = VocabParallelEmbedding(vocab_size,
|
||||||
|
config.n_embd,
|
||||||
|
perform_initialization=False)
|
||||||
|
self.h = nn.ModuleList(
|
||||||
|
[QWenBlock(config) for _ in range(config.num_hidden_layers)])
|
||||||
|
self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.wte(input_ids)
|
||||||
|
for i in range(len(self.h)):
|
||||||
|
if cache_events is None:
|
||||||
|
cache_event = None
|
||||||
|
else:
|
||||||
|
cache_event = cache_events[i]
|
||||||
|
layer = self.h[i]
|
||||||
|
hidden_states = layer(
|
||||||
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
kv_caches[i],
|
||||||
|
input_metadata,
|
||||||
|
cache_event,
|
||||||
|
)
|
||||||
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class QWenLMHeadModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: QWenConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.transformer = QWenModel(config)
|
||||||
|
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
|
self.lm_head = ColumnParallelLinear(
|
||||||
|
config.n_embd,
|
||||||
|
vocab_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=False,
|
||||||
|
perform_initialization=False,
|
||||||
|
)
|
||||||
|
self.sampler = Sampler(config.vocab_size)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[KVCache],
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
) -> SamplerOutput:
|
||||||
|
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||||
|
input_metadata, cache_events)
|
||||||
|
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||||
|
input_metadata)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
_column_parallel_weights = []
|
||||||
|
_row_parallel_weights = ["c_proj.weight"]
|
||||||
|
|
||||||
|
def load_weights(
|
||||||
|
self,
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
load_format: str = "auto",
|
||||||
|
):
|
||||||
|
tp_world_size = get_tensor_model_parallel_world_size()
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
|
for name, loaded_weight in hf_model_weights_iterator(
|
||||||
|
model_name_or_path, cache_dir, load_format):
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
|
|
||||||
|
if "c_attn" in name:
|
||||||
|
total_num_heads = self.config.num_attention_heads
|
||||||
|
hidden_size = self.config.hidden_size
|
||||||
|
head_size = hidden_size // total_num_heads
|
||||||
|
num_heads = total_num_heads // tp_world_size
|
||||||
|
head_start = tp_rank * num_heads
|
||||||
|
head_end = (tp_rank + 1) * num_heads
|
||||||
|
|
||||||
|
if "weight" in name:
|
||||||
|
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||||
|
head_size, hidden_size)
|
||||||
|
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
|
||||||
|
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||||
|
elif "bias" in name:
|
||||||
|
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||||
|
head_size)
|
||||||
|
loaded_weight = loaded_weight[:, head_start:head_end, :]
|
||||||
|
loaded_weight = loaded_weight.reshape(-1)
|
||||||
|
|
||||||
|
is_gate_up_weight = False
|
||||||
|
for stride_id, weight_name in enumerate(["w2", "w1"]):
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||||
|
shard_size = param.shape[0] // 2
|
||||||
|
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
|
||||||
|
(tp_rank + 1)]
|
||||||
|
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||||
|
(stride_id + 1)]
|
||||||
|
assert param_slice.shape == loaded_weight.shape
|
||||||
|
param_slice.copy_(loaded_weight)
|
||||||
|
is_gate_up_weight = True
|
||||||
|
break
|
||||||
|
if is_gate_up_weight:
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = state_dict[name]
|
||||||
|
|
||||||
|
if "wte" in name or "lm_head" in name:
|
||||||
|
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||||
|
tp_rank)
|
||||||
|
continue
|
||||||
|
|
||||||
|
load_tensor_parallel_weights(
|
||||||
|
param,
|
||||||
|
loaded_weight,
|
||||||
|
name,
|
||||||
|
self._column_parallel_weights,
|
||||||
|
self._row_parallel_weights,
|
||||||
|
tp_rank,
|
||||||
|
)
|
@ -3,13 +3,19 @@ import filelock
|
|||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Iterator, List, Optional, Tuple
|
from collections import defaultdict
|
||||||
|
from typing import Iterator, List, Optional, Tuple, Any
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
from safetensors.torch import load_file, save_file, safe_open
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Disabledtqdm(tqdm):
|
class Disabledtqdm(tqdm):
|
||||||
|
|
||||||
@ -17,43 +23,140 @@ class Disabledtqdm(tqdm):
|
|||||||
super().__init__(*args, **kwargs, disable=True)
|
super().__init__(*args, **kwargs, disable=True)
|
||||||
|
|
||||||
|
|
||||||
def hf_model_weights_iterator(
|
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
|
||||||
model_name_or_path: str,
|
|
||||||
cache_dir: Optional[str] = None,
|
|
||||||
use_np_cache: bool = False,
|
|
||||||
) -> Iterator[Tuple[str, torch.Tensor]]:
|
|
||||||
# Prepare file lock directory to prevent multiple processes from
|
|
||||||
# downloading the same model weights at the same time.
|
|
||||||
lock_dir = cache_dir if cache_dir is not None else "/tmp"
|
lock_dir = cache_dir if cache_dir is not None else "/tmp"
|
||||||
lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
|
lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
|
||||||
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
|
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
|
||||||
|
return lock
|
||||||
|
|
||||||
|
|
||||||
|
def _shared_pointers(tensors):
|
||||||
|
ptrs = defaultdict(list)
|
||||||
|
for k, v in tensors.items():
|
||||||
|
ptrs[v.data_ptr()].append(k)
|
||||||
|
failing = []
|
||||||
|
for _, names in ptrs.items():
|
||||||
|
if len(names) > 1:
|
||||||
|
failing.append(names)
|
||||||
|
return failing
|
||||||
|
|
||||||
|
|
||||||
|
def convert_bin_to_safetensor_file(
|
||||||
|
pt_filename: str,
|
||||||
|
sf_filename: str,
|
||||||
|
):
|
||||||
|
loaded = torch.load(pt_filename, map_location="cpu")
|
||||||
|
if "state_dict" in loaded:
|
||||||
|
loaded = loaded["state_dict"]
|
||||||
|
shared = _shared_pointers(loaded)
|
||||||
|
for shared_weights in shared:
|
||||||
|
for name in shared_weights[1:]:
|
||||||
|
loaded.pop(name)
|
||||||
|
|
||||||
|
# For tensors to be contiguous
|
||||||
|
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
||||||
|
|
||||||
|
dirname = os.path.dirname(sf_filename)
|
||||||
|
os.makedirs(dirname, exist_ok=True)
|
||||||
|
save_file(loaded, sf_filename, metadata={"format": "pt"})
|
||||||
|
|
||||||
|
# check file size
|
||||||
|
sf_size = os.stat(sf_filename).st_size
|
||||||
|
pt_size = os.stat(pt_filename).st_size
|
||||||
|
if (sf_size - pt_size) / pt_size > 0.01:
|
||||||
|
raise RuntimeError(f"""The file size different is more than 1%:
|
||||||
|
- {sf_filename}: {sf_size}
|
||||||
|
- {pt_filename}: {pt_size}
|
||||||
|
""")
|
||||||
|
|
||||||
|
# check if the tensors are the same
|
||||||
|
reloaded = load_file(sf_filename)
|
||||||
|
for k in loaded:
|
||||||
|
pt_tensor = loaded[k]
|
||||||
|
sf_tensor = reloaded[k]
|
||||||
|
if not torch.equal(pt_tensor, sf_tensor):
|
||||||
|
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_hf_model_weights(
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
use_safetensors: bool = False,
|
||||||
|
fall_back_to_pt: bool = True,
|
||||||
|
):
|
||||||
# Download model weights from huggingface.
|
# Download model weights from huggingface.
|
||||||
is_local = os.path.isdir(model_name_or_path)
|
is_local = os.path.isdir(model_name_or_path)
|
||||||
|
allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
|
||||||
if not is_local:
|
if not is_local:
|
||||||
with lock:
|
# Use file lock to prevent multiple processes from
|
||||||
|
# downloading the same model weights at the same time.
|
||||||
|
with get_lock(model_name_or_path, cache_dir):
|
||||||
hf_folder = snapshot_download(model_name_or_path,
|
hf_folder = snapshot_download(model_name_or_path,
|
||||||
allow_patterns="*.bin",
|
allow_patterns=allow_patterns,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
tqdm_class=Disabledtqdm)
|
tqdm_class=Disabledtqdm)
|
||||||
else:
|
else:
|
||||||
hf_folder = model_name_or_path
|
hf_folder = model_name_or_path
|
||||||
|
hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
|
||||||
hf_bin_files = [
|
if not use_safetensors:
|
||||||
x for x in glob.glob(os.path.join(hf_folder, "*.bin"))
|
hf_weights_files = [
|
||||||
if not x.endswith("training_args.bin")
|
x for x in hf_weights_files if not x.endswith("training_args.bin")
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
|
||||||
|
return prepare_hf_model_weights(model_name_or_path,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
use_safetensors=False,
|
||||||
|
fall_back_to_pt=False)
|
||||||
|
|
||||||
|
if len(hf_weights_files) == 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot find any model weights with `{model_name_or_path}`")
|
||||||
|
|
||||||
|
return hf_folder, hf_weights_files, use_safetensors
|
||||||
|
|
||||||
|
|
||||||
|
def hf_model_weights_iterator(
|
||||||
|
model_name_or_path: str,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
load_format: str = "auto",
|
||||||
|
) -> Iterator[Tuple[str, torch.Tensor]]:
|
||||||
|
use_safetensors = False
|
||||||
|
use_np_cache = False
|
||||||
|
fall_back_to_pt = False
|
||||||
|
if load_format == "auto":
|
||||||
|
use_safetensors = True
|
||||||
|
fall_back_to_pt = True
|
||||||
|
elif load_format == "safetensors":
|
||||||
|
use_safetensors = True
|
||||||
|
elif load_format == "pt":
|
||||||
|
pass
|
||||||
|
elif load_format == "npcache":
|
||||||
|
use_np_cache = True
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown load_format: {load_format}")
|
||||||
|
|
||||||
|
hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
|
||||||
|
model_name_or_path,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
|
fall_back_to_pt=fall_back_to_pt)
|
||||||
|
|
||||||
if use_np_cache:
|
if use_np_cache:
|
||||||
|
# Currently np_cache only support *.bin checkpoints
|
||||||
|
assert use_safetensors is False
|
||||||
|
|
||||||
# Convert the model weights from torch tensors to numpy arrays for
|
# Convert the model weights from torch tensors to numpy arrays for
|
||||||
# faster loading.
|
# faster loading.
|
||||||
np_folder = os.path.join(hf_folder, "np")
|
np_folder = os.path.join(hf_folder, "np")
|
||||||
os.makedirs(np_folder, exist_ok=True)
|
os.makedirs(np_folder, exist_ok=True)
|
||||||
weight_names_file = os.path.join(np_folder, "weight_names.json")
|
weight_names_file = os.path.join(np_folder, "weight_names.json")
|
||||||
with lock:
|
# Use file lock to prevent multiple processes from
|
||||||
|
# dumping the same model weights to numpy at the same time.
|
||||||
|
with get_lock(model_name_or_path, cache_dir):
|
||||||
if not os.path.exists(weight_names_file):
|
if not os.path.exists(weight_names_file):
|
||||||
weight_names = []
|
weight_names = []
|
||||||
for bin_file in hf_bin_files:
|
for bin_file in hf_weights_files:
|
||||||
state = torch.load(bin_file, map_location="cpu")
|
state = torch.load(bin_file, map_location="cpu")
|
||||||
for name, param in state.items():
|
for name, param in state.items():
|
||||||
param_path = os.path.join(np_folder, name)
|
param_path = os.path.join(np_folder, name)
|
||||||
@ -71,16 +174,52 @@ def hf_model_weights_iterator(
|
|||||||
with open(param_path, "rb") as f:
|
with open(param_path, "rb") as f:
|
||||||
param = np.load(f)
|
param = np.load(f)
|
||||||
yield name, torch.from_numpy(param)
|
yield name, torch.from_numpy(param)
|
||||||
|
elif use_safetensors:
|
||||||
|
for st_file in hf_weights_files:
|
||||||
|
with safe_open(st_file, framework="pt") as f:
|
||||||
|
for name in f.keys():
|
||||||
|
param = f.get_slice(name)
|
||||||
|
yield name, param
|
||||||
else:
|
else:
|
||||||
for bin_file in hf_bin_files:
|
for bin_file in hf_weights_files:
|
||||||
state = torch.load(bin_file, map_location="cpu")
|
state = torch.load(bin_file, map_location="cpu")
|
||||||
for name, param in state.items():
|
for name, param in state.items():
|
||||||
yield name, param
|
yield name, param
|
||||||
|
del state
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
|
||||||
|
"""convert PySafeSlice object from safetensors to torch.Tensor
|
||||||
|
|
||||||
|
PySafeSlice object supports indexing, which is done before loading the
|
||||||
|
actual tensor and can reduce the amount of memory being read into the
|
||||||
|
memory. However, it does not support more advanced functionalities
|
||||||
|
like `.view()` or `.t()`. Therefore, if we need to modify the loaded
|
||||||
|
tensor with these more complicated operators, we need to convert to
|
||||||
|
tensor first.
|
||||||
|
"""
|
||||||
|
if not isinstance(x, torch.Tensor):
|
||||||
|
x = x[:]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def load_padded_tensor_parallel_vocab(
|
||||||
|
param: torch.Tensor,
|
||||||
|
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
|
||||||
|
tensor_model_parallel_rank: int,
|
||||||
|
) -> None:
|
||||||
|
shard_size = param.shape[0]
|
||||||
|
start_idx = tensor_model_parallel_rank * shard_size
|
||||||
|
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||||
|
loaded_weight = loaded_weight[start_idx:end_idx]
|
||||||
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
|
param[:loaded_weight.shape[0]].copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
def load_tensor_parallel_weights(
|
def load_tensor_parallel_weights(
|
||||||
param: torch.Tensor,
|
param: torch.Tensor,
|
||||||
loaded_weight: torch.Tensor,
|
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
|
||||||
param_name: str,
|
param_name: str,
|
||||||
column_parallel_weight_names: List[str],
|
column_parallel_weight_names: List[str],
|
||||||
row_parallel_weight_names: List[str],
|
row_parallel_weight_names: List[str],
|
||||||
@ -100,6 +239,8 @@ def load_tensor_parallel_weights(
|
|||||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||||
loaded_weight = loaded_weight[:, start_idx:end_idx]
|
loaded_weight = loaded_weight[:, start_idx:end_idx]
|
||||||
break
|
break
|
||||||
|
|
||||||
|
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||||
assert param.shape == loaded_weight.shape, (
|
assert param.shape == loaded_weight.shape, (
|
||||||
f"{param_name} shape mismatch between model and checkpoint: "
|
f"{param_name} shape mismatch between model and checkpoint: "
|
||||||
f"{param.shape} != {loaded_weight.shape}")
|
f"{param.shape} != {loaded_weight.shape}")
|
||||||
|
@ -75,10 +75,12 @@ class RequestOutput:
|
|||||||
# Get the top-n sequences.
|
# Get the top-n sequences.
|
||||||
n = seq_group.sampling_params.n
|
n = seq_group.sampling_params.n
|
||||||
seqs = seq_group.get_seqs()
|
seqs = seq_group.get_seqs()
|
||||||
assert n <= len(seqs)
|
if seq_group.sampling_params.use_beam_search:
|
||||||
sorted_seqs = sorted(seqs,
|
sorting_key = lambda seq: seq.get_beam_search_score(
|
||||||
key=lambda seq: seq.get_cumulative_logprob(),
|
seq_group.sampling_params.length_penalty)
|
||||||
reverse=True)
|
else:
|
||||||
|
sorting_key = lambda seq: seq.get_cumulative_logprob()
|
||||||
|
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
|
||||||
top_n_seqs = sorted_seqs[:n]
|
top_n_seqs = sorted_seqs[:n]
|
||||||
|
|
||||||
# Create the outputs.
|
# Create the outputs.
|
||||||
|
@ -34,6 +34,15 @@ class SamplingParams:
|
|||||||
top_k: Integer that controls the number of top tokens to consider. Set
|
top_k: Integer that controls the number of top tokens to consider. Set
|
||||||
to -1 to consider all tokens.
|
to -1 to consider all tokens.
|
||||||
use_beam_search: Whether to use beam search instead of sampling.
|
use_beam_search: Whether to use beam search instead of sampling.
|
||||||
|
length_penalty: Float that penalizes sequences based on their length.
|
||||||
|
Used in beam search.
|
||||||
|
early_stopping: Controls the stopping condition for beam search. It
|
||||||
|
accepts the following values: `True`, where the generation stops as
|
||||||
|
soon as there are `best_of` complete candidates; `False`, where a
|
||||||
|
heuristic is applied and the generation stops when it is very
|
||||||
|
unlikely to find better candidates; `"never"`, where the beam search
|
||||||
|
procedure only stops when there cannot be better candidates
|
||||||
|
(canonical beam search algorithm).
|
||||||
stop: List of strings that stop the generation when they are generated.
|
stop: List of strings that stop the generation when they are generated.
|
||||||
The returned output will not contain the stop strings.
|
The returned output will not contain the stop strings.
|
||||||
ignore_eos: Whether to ignore the EOS token and continue generating
|
ignore_eos: Whether to ignore the EOS token and continue generating
|
||||||
@ -52,6 +61,8 @@ class SamplingParams:
|
|||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
use_beam_search: bool = False,
|
use_beam_search: bool = False,
|
||||||
|
length_penalty: float = 1.0,
|
||||||
|
early_stopping: Union[bool, str] = False,
|
||||||
stop: Union[None, str, List[str]] = None,
|
stop: Union[None, str, List[str]] = None,
|
||||||
ignore_eos: bool = False,
|
ignore_eos: bool = False,
|
||||||
max_tokens: int = 16,
|
max_tokens: int = 16,
|
||||||
@ -65,6 +76,8 @@ class SamplingParams:
|
|||||||
self.top_p = top_p
|
self.top_p = top_p
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.use_beam_search = use_beam_search
|
self.use_beam_search = use_beam_search
|
||||||
|
self.length_penalty = length_penalty
|
||||||
|
self.early_stopping = early_stopping
|
||||||
if stop is None:
|
if stop is None:
|
||||||
self.stop = []
|
self.stop = []
|
||||||
elif isinstance(stop, str):
|
elif isinstance(stop, str):
|
||||||
@ -77,8 +90,10 @@ class SamplingParams:
|
|||||||
|
|
||||||
self._verify_args()
|
self._verify_args()
|
||||||
if self.use_beam_search:
|
if self.use_beam_search:
|
||||||
self._verity_beam_search()
|
self._verify_beam_search()
|
||||||
elif self.temperature < _SAMPLING_EPS:
|
else:
|
||||||
|
self._verify_non_beam_search()
|
||||||
|
if self.temperature < _SAMPLING_EPS:
|
||||||
# Zero temperature means greedy sampling.
|
# Zero temperature means greedy sampling.
|
||||||
self._verify_greedy_sampling()
|
self._verify_greedy_sampling()
|
||||||
|
|
||||||
@ -109,7 +124,7 @@ class SamplingParams:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"logprobs must be non-negative, got {self.logprobs}.")
|
f"logprobs must be non-negative, got {self.logprobs}.")
|
||||||
|
|
||||||
def _verity_beam_search(self) -> None:
|
def _verify_beam_search(self) -> None:
|
||||||
if self.best_of == 1:
|
if self.best_of == 1:
|
||||||
raise ValueError("best_of must be greater than 1 when using beam "
|
raise ValueError("best_of must be greater than 1 when using beam "
|
||||||
f"search. Got {self.best_of}.")
|
f"search. Got {self.best_of}.")
|
||||||
@ -119,6 +134,20 @@ class SamplingParams:
|
|||||||
raise ValueError("top_p must be 1 when using beam search.")
|
raise ValueError("top_p must be 1 when using beam search.")
|
||||||
if self.top_k != -1:
|
if self.top_k != -1:
|
||||||
raise ValueError("top_k must be -1 when using beam search.")
|
raise ValueError("top_k must be -1 when using beam search.")
|
||||||
|
if self.early_stopping not in [True, False, "never"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"early_stopping must be True, False, or 'never', "
|
||||||
|
f"got {self.early_stopping}.")
|
||||||
|
|
||||||
|
def _verify_non_beam_search(self) -> None:
    """Reject beam-search-only options when beam search is not in use.

    Raises:
        ValueError: If `early_stopping` is anything other than `False`, or
            if `length_penalty` lies more than `_SAMPLING_EPS` away
            from 1.0.
    """
    if self.early_stopping is not False:
        raise ValueError("early_stopping is not effective and must be "
                         "False when not using beam search.")

    # Compare with a tolerance so that float noise around the default of
    # 1.0 is still accepted.
    penalty = self.length_penalty
    lower_bound = 1.0 - _SAMPLING_EPS
    upper_bound = 1.0 + _SAMPLING_EPS
    if penalty < lower_bound or penalty > upper_bound:
        raise ValueError(
            "length_penalty is not effective and must be the "
            "default value of 1.0 when not using beam search.")
|
||||||
|
|
||||||
def _verify_greedy_sampling(self) -> None:
|
def _verify_greedy_sampling(self) -> None:
|
||||||
if self.best_of > 1:
|
if self.best_of > 1:
|
||||||
@ -138,6 +167,8 @@ class SamplingParams:
|
|||||||
f"top_p={self.top_p}, "
|
f"top_p={self.top_p}, "
|
||||||
f"top_k={self.top_k}, "
|
f"top_k={self.top_k}, "
|
||||||
f"use_beam_search={self.use_beam_search}, "
|
f"use_beam_search={self.use_beam_search}, "
|
||||||
|
f"length_penalty={self.length_penalty}, "
|
||||||
|
f"early_stopping={self.early_stopping}, "
|
||||||
f"stop={self.stop}, "
|
f"stop={self.stop}, "
|
||||||
f"ignore_eos={self.ignore_eos}, "
|
f"ignore_eos={self.ignore_eos}, "
|
||||||
f"max_tokens={self.max_tokens}, "
|
f"max_tokens={self.max_tokens}, "
|
||||||
|
@ -69,6 +69,9 @@ class SequenceData:
|
|||||||
def get_len(self) -> int:
    """Return the combined number of prompt and output token IDs."""
    num_output_tokens = len(self.output_token_ids)
    num_prompt_tokens = len(self.prompt_token_ids)
    return num_output_tokens + num_prompt_tokens
|
||||||
|
|
||||||
|
def get_prompt_len(self) -> int:
    """Return the number of token IDs in the prompt."""
    prompt_ids = self.prompt_token_ids
    return len(prompt_ids)
|
||||||
|
|
||||||
def get_output_len(self) -> int:
    """Return the number of output token IDs stored so far."""
    output_ids = self.output_token_ids
    return len(output_ids)
|
||||||
|
|
||||||
@ -155,6 +158,9 @@ class Sequence:
|
|||||||
def get_len(self) -> int:
    """Return the length reported by this sequence's `data` object."""
    seq_data = self.data
    return seq_data.get_len()
|
||||||
|
|
||||||
|
def get_prompt_len(self) -> int:
    """Return the prompt length reported by this sequence's `data` object."""
    seq_data = self.data
    return seq_data.get_prompt_len()
|
||||||
|
|
||||||
def get_output_len(self) -> int:
    """Return the output length reported by this sequence's `data` object."""
    seq_data = self.data
    return seq_data.get_output_len()
|
||||||
|
|
||||||
@ -170,14 +176,32 @@ class Sequence:
|
|||||||
def get_cumulative_logprob(self) -> float:
    """Return the `cumulative_logprob` value held by `self.data`."""
    seq_data = self.data
    return seq_data.cumulative_logprob
|
||||||
|
|
||||||
|
def get_beam_search_score(self,
                          length_penalty: float = 0.0,
                          seq_len: Optional[int] = None,
                          eos_token_id: Optional[int] = None) -> float:
    """Return the length-penalized beam search score of this sequence.

    The score is the cumulative log-probability divided by
    `seq_len ** length_penalty`.

    Adapted from
    https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938

    Args:
        length_penalty: Exponent applied to the sequence length.
        seq_len: Length to normalize by. When `None`, `self.get_len()` is
            used, minus one if the last token equals `eos_token_id`.
        eos_token_id: EOS token ID; only consulted when `seq_len` is
            `None`.
    """
    effective_len = seq_len
    if effective_len is None:
        effective_len = self.get_len()
        # NOTE: The HF implementation does not count the EOS token towards
        # the length; we align with that here for testing.
        eos_given = eos_token_id is not None
        if eos_given and self.get_last_token_id() == eos_token_id:
            effective_len -= 1
    cumulative_logprob = self.get_cumulative_logprob()
    return cumulative_logprob / (effective_len**length_penalty)
|
||||||
|
|
||||||
def is_finished(self) -> bool:
    """Return what `SequenceStatus.is_finished` reports for our status."""
    current_status = self.status
    return SequenceStatus.is_finished(current_status)
|
||||||
|
|
||||||
def fork(self, new_seq_id: int) -> "Sequence":
    """Return a deep copy of this sequence carrying `new_seq_id`.

    The copy is made with `copy.deepcopy`, so the returned object does not
    share mutable state with `self`; only its `seq_id` is overwritten.
    """
    forked_seq = copy.deepcopy(self)
    forked_seq.seq_id = new_seq_id
    return forked_seq
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (f"Sequence(seq_id={self.seq_id}, "
|
return (f"Sequence(seq_id={self.seq_id}, "
|
||||||
@ -203,35 +227,66 @@ class SequenceGroup:
|
|||||||
arrival_time: float,
|
arrival_time: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
self.seqs = seqs
|
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
|
||||||
self.sampling_params = sampling_params
|
self.sampling_params = sampling_params
|
||||||
self.arrival_time = arrival_time
|
self.arrival_time = arrival_time
|
||||||
|
|
||||||
|
def get_max_num_running_seqs(self) -> int:
    """The maximum number of sequences running in parallel in the remaining
    lifetime of the request."""
    params = self.sampling_params

    # For beam search, maximally there will always be `best_of` beam
    # candidates running in the future.
    if params.use_beam_search:
        return params.best_of

    # At prompt stage, the sequence group is not yet filled up and only
    # has one sequence running. However, in the generation stage, we will
    # have `best_of` sequences running.
    if params.best_of > self.num_seqs():
        return params.best_of

    # At sampling stages, return the number of actual sequences running.
    return self.num_seqs(status=SequenceStatus.RUNNING)
|
||||||
|
|
||||||
def get_seqs(
    self,
    status: Optional[SequenceStatus] = None,
) -> List[Sequence]:
    """Return the group's sequences, optionally filtered by status.

    Args:
        status: When `None`, every sequence is returned; otherwise only
            sequences whose `status` equals this value.

    Returns:
        A new list, in the iteration order of `self.seqs_dict`.
    """
    all_seqs = list(self.seqs_dict.values())
    if status is None:
        return all_seqs
    return [seq for seq in all_seqs if seq.status == status]
|
||||||
|
|
||||||
|
def get_finished_seqs(self) -> List[Sequence]:
    """Return the sequences of this group whose `is_finished()` is true."""
    finished = []
    for seq in self.seqs_dict.values():
        if seq.is_finished():
            finished.append(seq)
    return finished
|
||||||
|
|
||||||
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
    """Return how many sequences `get_seqs(status)` yields."""
    matching_seqs = self.get_seqs(status)
    return len(matching_seqs)
|
||||||
|
|
||||||
def find(self, seq_id: int) -> Sequence:
    """Look up a sequence of this group by its ID.

    Raises:
        ValueError: If no sequence with `seq_id` is in the group.
    """
    if seq_id in self.seqs_dict:
        return self.seqs_dict[seq_id]
    raise ValueError(f"Sequence {seq_id} not found.")
|
||||||
|
|
||||||
|
def add(self, seq: Sequence) -> None:
    """Register `seq` in this group under its `seq_id`.

    Raises:
        ValueError: If a sequence with the same ID is already present.
    """
    already_present = seq.seq_id in self.seqs_dict
    if already_present:
        raise ValueError(f"Sequence {seq.seq_id} already exists.")
    self.seqs_dict[seq.seq_id] = seq
|
||||||
|
|
||||||
|
def remove(self, seq_id: int) -> None:
    """Drop the sequence with the given ID from this group.

    Raises:
        ValueError: If no sequence with `seq_id` is in the group.
    """
    if seq_id in self.seqs_dict:
        del self.seqs_dict[seq_id]
        return
    raise ValueError(f"Sequence {seq_id} not found.")
|
||||||
|
|
||||||
def is_finished(self) -> bool:
    """Return True when every sequence in the group is finished.

    A group without sequences counts as finished.
    """
    for seq in self.get_seqs():
        if not seq.is_finished():
            return False
    return True
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (f"SequenceGroup(request_id={self.request_id}, "
|
return (f"SequenceGroup(request_id={self.request_id}, "
|
||||||
f"sampling_params={self.sampling_params}, "
|
f"sampling_params={self.sampling_params}, "
|
||||||
f"num_seqs={len(self.seqs)})")
|
f"num_seqs={len(self.seqs_dict)})")
|
||||||
|
|
||||||
|
|
||||||
class SequenceGroupMetadata:
|
class SequenceGroupMetadata:
|
||||||
@ -266,7 +321,6 @@ class SequenceOutputs:
|
|||||||
"""The model output associated with a sequence.
|
"""The model output associated with a sequence.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
seq_id: The ID of the sequence.
|
|
||||||
parent_seq_id: The ID of the parent sequence (for forking in beam
|
parent_seq_id: The ID of the parent sequence (for forking in beam
|
||||||
search).
|
search).
|
||||||
output_token: The output token ID.
|
output_token: The output token ID.
|
||||||
@ -276,26 +330,27 @@ class SequenceOutputs:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
seq_id: int,
|
|
||||||
parent_seq_id: int,
|
parent_seq_id: int,
|
||||||
output_token: int,
|
output_token: int,
|
||||||
logprobs: Dict[int, float],
|
logprobs: Dict[int, float],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.seq_id = seq_id
|
|
||||||
self.parent_seq_id = parent_seq_id
|
self.parent_seq_id = parent_seq_id
|
||||||
self.output_token = output_token
|
self.output_token = output_token
|
||||||
self.logprobs = logprobs
|
self.logprobs = logprobs
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (f"SequenceOutputs(seq_id={self.seq_id}, "
|
return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
|
||||||
f"parent_seq_id={self.parent_seq_id}, "
|
|
||||||
f"output_token={self.output_token}), "
|
f"output_token={self.output_token}), "
|
||||||
f"logprobs={self.logprobs}")
|
f"logprobs={self.logprobs}")
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
    """Compare two `SequenceOutputs` field by field.

    Returns:
        `NotImplemented` when `other` is not a `SequenceOutputs`, so that
        Python falls back to the reflected comparison and finally to
        identity. Otherwise True only if `parent_seq_id`, `output_token`
        and `logprobs` are all equal.
    """
    if not isinstance(other, SequenceOutputs):
        # Must be the `NotImplemented` sentinel: returning an exception
        # instance such as `NotImplementedError()` is truthy and would make
        # `outputs == <anything>` evaluate as true.
        return NotImplemented
    return (self.parent_seq_id == other.parent_seq_id
            and self.output_token == other.output_token
            and self.logprobs == other.logprobs)
|
||||||
|
|
||||||
|
|
||||||
|
# For each sequence group, we generate a list of SequenceOutputs object,
|
||||||
|
# each of which contains one possible candidate for the next token.
|
||||||
|
SamplerOutput = List[List[SequenceOutputs]]
|
||||||
|
@ -5,6 +5,8 @@ from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import
|
|||||||
_CONFIG_REGISTRY = {
|
_CONFIG_REGISTRY = {
|
||||||
"mpt": MPTConfig,
|
"mpt": MPTConfig,
|
||||||
"baichuan": BaiChuanConfig,
|
"baichuan": BaiChuanConfig,
|
||||||
|
"aquila": AquilaConfig,
|
||||||
|
"qwen": QWenConfig,
|
||||||
"RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
|
"RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
|
||||||
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
|
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
from vllm.transformers_utils.configs.mpt import MPTConfig
|
from vllm.transformers_utils.configs.mpt import MPTConfig
|
||||||
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
|
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
|
||||||
|
from vllm.transformers_utils.configs.aquila import AquilaConfig
|
||||||
|
from vllm.transformers_utils.configs.qwen import QWenConfig
|
||||||
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
|
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
|
||||||
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
|
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
|
||||||
# `FalconConfig` class from the official HuggingFace transformers library.
|
# `FalconConfig` class from the official HuggingFace transformers library.
|
||||||
@ -8,5 +10,7 @@ from vllm.transformers_utils.configs.falcon import RWConfig
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"MPTConfig",
|
"MPTConfig",
|
||||||
"BaiChuanConfig",
|
"BaiChuanConfig",
|
||||||
|
"AquilaConfig",
|
||||||
|
"QWenConfig",
|
||||||
"RWConfig",
|
"RWConfig",
|
||||||
]
|
]
|
||||||
|
63
vllm/transformers_utils/configs/aquila.py
Normal file
63
vllm/transformers_utils/configs/aquila.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Aquila model configuration"""
|
||||||
|
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class AquilaConfig(PretrainedConfig):
    """Configuration container for models with `model_type` "aquila".

    The architecture hyperparameters are stored as instance attributes.
    The special token IDs, `tie_word_embeddings` and any extra keyword
    arguments are forwarded to `PretrainedConfig.__init__`.
    """

    model_type = "aquila"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=100008,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.006,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        **kwargs,
    ):
        # Architecture hyperparameters are set before the base-class
        # initializer runs.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache

        # Token IDs and embedding tying are handled by the base class.
        special_token_ids = {
            "pad_token_id": pad_token_id,
            "bos_token_id": bos_token_id,
            "eos_token_id": eos_token_id,
        }
        super().__init__(
            **special_token_ids,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|
71
vllm/transformers_utils/configs/qwen.py
Normal file
71
vllm/transformers_utils/configs/qwen.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
# Copyright (c) Alibaba Cloud.
|
||||||
|
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
|
||||||
|
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class QWenConfig(PretrainedConfig):
    """Configuration container for models with `model_type` "qwen".

    `attribute_map` aliases `hidden_size`, `num_attention_heads`,
    `max_position_embeddings` and `num_hidden_layers` onto the `n_embd`,
    `n_head`, `n_positions` and `n_layer` fields respectively.
    """

    model_type = "qwen"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "hidden_size": "n_embd",
        "num_attention_heads": "n_head",
        "max_position_embeddings": "n_positions",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size=151851,
        n_embd=4096,
        n_layer=32,
        n_head=32,
        n_inner=None,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        scale_attn_weights=True,
        use_cache=True,
        eos_token_id=151643,
        apply_residual_connection_post_layernorm=False,
        bf16=True,
        kv_channels=128,
        rotary_pct=1.0,
        rotary_emb_base=10000,
        use_dynamic_ntk=False,
        use_logn_attn=False,
        use_flash_attn=True,
        ffn_hidden_size=22016,
        no_bias=True,
        tie_word_embeddings=False,
        **kwargs,
    ):
        # `eos_token_id` is assigned here and also passed to the base class.
        self.eos_token_id = eos_token_id
        super().__init__(
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        # Model size.
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner

        # Dropout, normalization and initialization settings.
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm)

        # Remaining options are stored verbatim.
        self.bf16 = bf16
        self.kv_channels = kv_channels
        self.rotary_pct = rotary_pct
        self.rotary_emb_base = rotary_emb_base
        self.use_dynamic_ntk = use_dynamic_ntk
        self.use_logn_attn = use_logn_attn
        self.use_flash_attn = use_flash_attn
        self.ffn_hidden_size = ffn_hidden_size
        self.no_bias = no_bias
        self.tie_word_embeddings = tie_word_embeddings
|
@ -25,7 +25,8 @@ def get_tokenizer(
|
|||||||
"Cannot use the fast tokenizer in slow tokenizer mode.")
|
"Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||||
kwargs["use_fast"] = False
|
kwargs["use_fast"] = False
|
||||||
|
|
||||||
if "llama" in tokenizer_name.lower() and kwargs.get("use_fast", True):
|
if ("llama" in tokenizer_name.lower() and kwargs.get("use_fast", True)
|
||||||
|
and tokenizer_name != _FAST_LLAMA_TOKENIZER):
|
||||||
logger.info(
|
logger.info(
|
||||||
"For some LLaMA-based models, initializing the fast tokenizer may "
|
"For some LLaMA-based models, initializing the fast tokenizer may "
|
||||||
"take a long time. To eliminate the initialization time, consider "
|
"take a long time. To eliminate the initialization time, consider "
|
||||||
@ -72,7 +73,7 @@ def detokenize_incrementally(
|
|||||||
new_token_id: int,
|
new_token_id: int,
|
||||||
skip_special_tokens: bool,
|
skip_special_tokens: bool,
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
"""Detokenizes the new token in conjuction with the previous output tokens.
|
"""Detokenizes the new token in conjunction with the previous output tokens.
|
||||||
|
|
||||||
NOTE: This function does not update prev_output_tokens.
|
NOTE: This function does not update prev_output_tokens.
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ from vllm.model_executor import get_model, InputMetadata, set_random_seed
|
|||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
initialize_model_parallel)
|
initialize_model_parallel)
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.sequence import SequenceData, SequenceGroupMetadata, SequenceOutputs
|
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
||||||
from vllm.worker.cache_engine import CacheEngine
|
from vllm.worker.cache_engine import CacheEngine
|
||||||
from vllm.utils import get_gpu_memory
|
from vllm.utils import get_gpu_memory
|
||||||
|
|
||||||
@ -260,7 +260,7 @@ class Worker:
|
|||||||
blocks_to_swap_in: Dict[int, int],
|
blocks_to_swap_in: Dict[int, int],
|
||||||
blocks_to_swap_out: Dict[int, int],
|
blocks_to_swap_out: Dict[int, int],
|
||||||
blocks_to_copy: Dict[int, List[int]],
|
blocks_to_copy: Dict[int, List[int]],
|
||||||
) -> Dict[int, SequenceOutputs]:
|
) -> SamplerOutput:
|
||||||
# Issue cache operations.
|
# Issue cache operations.
|
||||||
issued_cache_op = False
|
issued_cache_op = False
|
||||||
if blocks_to_swap_in:
|
if blocks_to_swap_in:
|
||||||
|
Reference in New Issue
Block a user