Compare commits

...

102 Commits

Author SHA1 Message Date
791d79de32 Bump up the version to v0.1.4 (#846) 2023-08-25 12:28:00 +09:00
94d2f59895 Set replacement=True in torch.multinomial (#858) 2023-08-25 12:22:01 +09:00
75c0ca9d43 Clean up code (#844) 2023-08-23 16:44:15 -07:00
2a4ec90854 Fix for breaking changes in xformers 0.0.21 (#834) 2023-08-23 17:44:21 +09:00
85ebcda94d Fix typo of Aquila in README.md (#836) 2023-08-22 20:48:36 -07:00
d64bf1646c Implement approximate GELU kernels (#828) 2023-08-23 07:43:21 +09:00
a41c20435e Add compute capability 8.9 to default targets (#829) 2023-08-23 07:28:38 +09:00
eedac9dba0 fix: revert code to avoid no attribute problem (#827) 2023-08-22 11:55:16 -07:00
14f9c72bfd Update Supported Model List (#825) 2023-08-22 11:51:44 -07:00
ad5f2fe34c Add support for aquila (#663)
* add aquila

Signed-off-by: ftgreat <ftgreat@163.com>

* fix some bug

Signed-off-by: shunxing1234 <xw747777271@gmail.com>

* delete pdb

Signed-off-by: shunxing1234 <xw747777271@gmail.com>

* fix bugs

Signed-off-by: shunxing1234 <xw747777271@gmail.com>

* fix bugs

Signed-off-by: shunxing1234 <xw747777271@gmail.com>

* delete whitespace

Signed-off-by: shunxing1234 <xw747777271@gmail.com>

* format

* fix order

---------

Signed-off-by: ftgreat <ftgreat@163.com>
Signed-off-by: shunxing1234 <xw747777271@gmail.com>
Co-authored-by: ftgreat <ftgreat@163.com>
2023-08-22 00:13:36 -07:00
4f8584756d Fix mqa is false case in gpt_bigcode (#806) 2023-08-21 22:22:06 -07:00
65fc1c3127 set default coompute capability according to cuda version (#773) 2023-08-21 16:05:44 -07:00
c393af6cd7 [Feature | CI] Added a github action to build wheels (#746) 2023-08-21 16:59:15 +09:00
0c04ce3234 Fix typo in sampling_params.py (#788) 2023-08-18 10:12:46 +09:00
73b3de79ea explicitly del state (#784) 2023-08-17 12:56:04 -07:00
d1744376ae Align with huggingface Top K sampling (#753) 2023-08-15 16:44:33 -07:00
805de738f6 Fix typo in tokenizer.py (#750)
conjuction -> conjunction
2023-08-14 22:26:36 -07:00
1b151ed181 Fix baichuan doc style (#748) 2023-08-13 20:57:31 -07:00
e06f504a76 Supports tokens and arrays of tokens as inputs to the OpenAI completion API (#715) 2023-08-11 12:14:34 -07:00
WRH 462ae5220a [Fix] unwantted bias in InternLM Model (#740) 2023-08-11 11:40:37 -07:00
66c54aa9c3 Check the max prompt length for the OpenAI completions API (#472) 2023-08-08 17:43:49 -07:00
735ecfff61 add internlm model (#528) 2023-08-08 16:35:06 -07:00
a57d13cc96 add QWen-7b (#685)
Co-authored-by: wq.chu <wq.chu@tianrang-inc.com>
2023-08-08 13:50:38 -07:00
79af7e96a0 [OPTIMIZATION] Optimizes the single_query_cached_kv_attention kernel (#420) 2023-08-04 10:57:29 -07:00
621980bdc0 fix: incorrect bigcode attention heads num (#676) 2023-08-04 10:35:22 -07:00
aa84c92ef6 Bump up version to 0.1.3 (#657) 2023-08-02 16:46:53 -07:00
f7389f4763 [Doc] Add Baichuan 13B to supported models (#656) 2023-08-02 16:45:12 -07:00
55fe8a81ec Refactor scheduler (#658) 2023-08-02 16:42:01 -07:00
e8ddc08ec8 [BUG FIX] upgrade fschat version to 0.2.23 (#650)
Co-authored-by: hao.yu <hao.yu@cn-c017.server.mila.quebec>
2023-08-02 14:05:59 -07:00
1b0bd0fe8a Add Falcon support (new) (#592) 2023-08-02 14:04:39 -07:00
20044cab7a Fix log message in scheduler (#652) 2023-08-02 13:35:10 -07:00
64f23c2900 fix baichuan for different position embedding for 7b and 13b models (#643) 2023-08-01 22:22:51 -07:00
d4c7755ca8 fix biachuan-7b tp (#598)
Co-authored-by: wq.chu <wq.chu@tianrang-inc.com>
2023-08-01 15:41:36 -07:00
aa39e42c5a fix doc (#622) 2023-07-31 13:11:57 -07:00
953f28cf9a fix ModuleNotFoundError (#599)
Co-authored-by: fangli <fangli@tencent.com>
2023-07-29 20:52:41 -07:00
c0d00f5be6 [Fix] fix import error of RayWorker (#604) (#605) 2023-07-27 23:37:40 -07:00
58a072be15 [Fix] Add model sequence length into model config (#575) 2023-07-25 23:46:30 -07:00
82ad323dee [Fix] Add chat completion Example and simplify dependencies (#576) 2023-07-25 23:45:48 -07:00
df5dd3c68e Add Baichuan-7B to README (#494) 2023-07-25 15:25:12 -07:00
2d867b55fa fixed tensor parallel is not defined (#564) 2023-07-25 14:16:51 -07:00
d7a1c6d614 Fix paged attention testing. (#495)
Signed-off-by: Tao Peng <jiankeng.pt@alibaba-inc.com>
2023-07-24 21:01:56 -07:00
7d5a155e4a [Fix] Fix GPTBigcoder for distributed execution (#503) 2023-07-24 18:36:33 -07:00
1dde34e0f8 GPTJConfig has no attribute rotary. (#532) 2023-07-24 11:29:30 -07:00
6fc2a38b11 Add support for LLaMA-2 (#505) 2023-07-20 11:38:27 -07:00
c487a221ee Fix bad assert in initialize_cluster if PG already exists (#526) 2023-07-19 23:17:12 -07:00
9925c17940 Ray placement group support (#397) 2023-07-19 22:49:31 -07:00
8c4b2592fb fix: enable trust-remote-code in api server & benchmark. (#509) 2023-07-19 17:06:15 -07:00
WRH cf21a9bd5c support trust_remote_code in benchmark (#518) 2023-07-19 17:02:40 -07:00
16c3e295a8 fix(ray_utils): ignore re-init error (#465) 2023-07-19 17:01:19 -07:00
bda41c70dd hotfix attn alibi wo head mapping (#496)
Co-authored-by: oliveryuan <oliveryuan@basemind.com>
2023-07-18 11:31:48 -07:00
453bafb96f Merge pull request #498 from MoeedDar/main
Fixed old name reference for max_seq_len
2023-07-18 09:22:56 -07:00
328d231c17 Fixed old name reference for max_seq_len 2023-07-18 16:47:59 +01:00
b4b195b360 fix max seq len (#489) 2023-07-17 23:20:20 -07:00
20b0d88d16 Add support for baichuan (#365) 2023-07-17 13:50:55 -07:00
2bdea7ac11 [Fix] Fix the condition of max_seq_len (#477) 2023-07-17 00:33:48 -04:00
58df2883cb [Doc] Add doc for running vLLM on the cloud (#426)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2023-07-16 13:37:14 -07:00
6d7d95a70a Offload port selection to OS (#467) 2023-07-15 23:11:02 -07:00
96853af5a8 Optimize MQA Kernel (#452) 2023-07-14 20:06:40 -04:00
dbed69058c Fix the KeyError when loading bloom-based models (#441) 2023-07-13 21:58:09 -07:00
7b6ae94059 add vocab padding for LLama(Support WizardLM) (#411) 2023-07-13 23:56:22 -04:00
c6dfc3cdbe Fix handling of special tokens in decoding. (#418) 2023-07-12 11:14:56 -04:00
51be365143 fix: freeze pydantic to v1 (#429) 2023-07-12 11:10:55 -04:00
c894836108 [Model] Add support for GPT-J (#226)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2023-07-08 17:55:16 -07:00
75beba29b5 Don't try to load training_args.bin (#373) 2023-07-08 15:26:28 -07:00
ddfdf470ae Add trust_remote_code arg to get_config (#405) 2023-07-08 15:24:17 -07:00
b6fbb9a565 Sort the outputs before return (#402) 2023-07-08 14:48:18 -07:00
2179e4f4c5 avoid python list copy in sequence initialization (#401) 2023-07-08 12:42:08 -07:00
a945fcc2ae Add trust-remote-code flag to handle remote tokenizers (#364) 2023-07-07 11:04:58 -07:00
be54f8e5c4 [Fix] Change /generate response-type to json for non-streaming (#374) 2023-07-06 18:15:17 -07:00
b396cb4998 fix: only response [DONE] once when streaming response. (#378) 2023-07-06 18:08:40 -07:00
1c395b4eaa Bump up the version (#300) 2023-07-04 21:41:53 -07:00
3d64cf019e [Server] use fastchat.model.model_adapter.get_conversation_template method to get model template (#357) 2023-07-04 21:39:59 -07:00
98fe8cb542 [Server] Add option to specify chat template for chat endpoint (#345) 2023-07-03 23:01:56 -07:00
ffa6d2f9f9 [Docs] Fix typo (#346) 2023-07-03 16:51:47 -07:00
404422f42e [Model] Add support for MPT (#334) 2023-07-03 16:47:53 -07:00
7717d0838b Fix an endless loop issue when engine_step throws a RuntimeError (#339) 2023-07-03 15:22:28 -07:00
42e0c1df78 [Quality] Add CI for formatting (#343) 2023-07-03 14:50:56 -07:00
e41f06702c Add support for BLOOM (#331) 2023-07-03 13:12:35 -07:00
d6fa1be3a8 [Quality] Add code formatter and linter (#326) 2023-07-03 11:31:55 -07:00
0ffded812a [Fix] Better error message for batched prompts (#342) 2023-07-03 09:27:31 -07:00
0bd2a573a5 Allow send list of str for the Prompt on openai demo endpoint /v1/completions (#323)
* allow str or List[str] for prompt

* Update vllm/entrypoints/openai/api_server.py

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>

---------

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2023-07-03 09:17:50 -07:00
49b26e2cec feat: add ChatCompletion endpoint in OpenAI demo server. (#330) 2023-07-02 22:54:33 -07:00
dafd924c1f Raise error for long prompt (#273) 2023-06-30 18:48:49 -07:00
598dc4b79a [Fix] Weight loading for GPTBigCode (#313) 2023-06-29 22:14:17 -07:00
85de093472 [Fix] Do not pin memory when in WSL (#312) 2023-06-29 15:00:21 -07:00
f72297562f Add news for the vllm+skypilot example (#314) 2023-06-29 12:32:37 -07:00
9d27b09d12 Update README.md (#306) 2023-06-29 06:52:15 -07:00
998d9d1509 [Tokenizer] Add tokenizer mode (#298) 2023-06-28 14:19:22 -07:00
425040d4c1 remove floats == 0 comparison (#285) 2023-06-28 14:11:51 -07:00
4338cc4750 [Tokenizer] Add an option to specify tokenizer (#284) 2023-06-28 09:46:58 -07:00
bdd6b4c8bc Add LLM.set_tokenizer (#283) 2023-06-28 00:28:29 -07:00
2b7d3aca2e Update setup.py (#282)
Co-authored-by: neubig <neubig@gmail.com>
2023-06-27 14:34:23 -07:00
4026a049d3 expand coverage of gpt2 model loading (#271) 2023-06-27 06:27:41 -07:00
43710e8d09 [Fix] Fix default port number in benchmark scripts (#265) 2023-06-26 13:15:35 -07:00
526df28fb2 [BugFix] Fix a bug in counting running sequences (#266) 2023-06-26 13:09:02 -07:00
2cf1a333b6 [Doc] Documentation for distributed inference (#261) 2023-06-26 11:34:23 -07:00
0b7db411b5 [Bug] Fix the OOM condition for CPU cache (#260) 2023-06-26 11:16:13 -07:00
471a7a4566 Compatible with Decapoda Research llama hf version (#251) 2023-06-26 09:23:57 -07:00
6214dd6ce9 Update README.md (#236) 2023-06-25 16:58:06 -07:00
0603379863 fix wrong using getattr to get dict value (#232) 2023-06-24 22:00:24 -07:00
665c48963b [Docs] Add GPTBigCode to supported models (#213) 2023-06-22 15:05:11 -07:00
298695b766 GPTBigCode (StarCoder, SantaCoder Support) (#209) 2023-06-23 01:49:27 +08:00
96 changed files with 7052 additions and 1142 deletions

101
.github/workflows/publish.yml vendored Normal file

@ -0,0 +1,101 @@
# This workflow will upload a Python Package to Release asset
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions

name: Create Release

on:
  push:
    tags:
      - v*

# Needed to create release and upload assets
permissions:
  contents: write

jobs:
  release:
    # Retrieve tag and create release
    name: Create Release
    runs-on: ubuntu-latest
    outputs:
      upload_url: ${{ steps.create_release.outputs.upload_url }}
    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Extract branch info
        shell: bash
        run: |
          echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV

      - name: Create Release
        id: create_release
        uses: "actions/github-script@v6"
        env:
          RELEASE_TAG: ${{ env.release_tag }}
        with:
          github-token: "${{ secrets.GITHUB_TOKEN }}"
          script: |
            const script = require('.github/workflows/scripts/create_release.js')
            await script(github, context, core)

  wheel:
    name: Build Wheel
    runs-on: ${{ matrix.os }}
    needs: release

    strategy:
      fail-fast: false
      matrix:
        os: ['ubuntu-20.04']
        python-version: ['3.8', '3.9', '3.10', '3.11']
        cuda-version: ['11.8'] # Github runner can't build anything older than 11.8

    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Set up Linux Env
        if: ${{ runner.os == 'Linux' }}
        run: |
          bash -x .github/workflows/scripts/env.sh

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install CUDA ${{ matrix.cuda-version }}
        run: |
          bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}

      - name: Install PyTorch-cu${{ matrix.cuda-version }}
        run: |
          bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}

      - name: Build wheel
        shell: bash
        run: |
          bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
          wheel_name=$(ls dist/*whl | xargs -n 1 basename)
          asset_name=${wheel_name//"linux"/"manylinux1"}
          echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
          echo "asset_name=${asset_name}" >> $GITHUB_ENV

      - name: Upload Release Asset
        uses: actions/upload-release-asset@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          upload_url: ${{ needs.release.outputs.upload_url }}
          asset_path: ./dist/${{ env.wheel_name }}
          asset_name: ${{ env.asset_name }}
          asset_content_type: application/*

      # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
      # - name: Publish package
      #   uses: pypa/gh-action-pypi-publish@release/v1.8
      #   with:
      #     repository-url: https://test.pypi.org/legacy/
      #     password: ${{ secrets.PYPI_API_TOKEN }}
      #     skip-existing: true

31
.github/workflows/pylint.yml vendored Normal file

@ -0,0 +1,31 @@
name: pylint

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  pylint:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pylint==2.8.2
      - name: Analysing the code with pylint
        run: |
          pylint vllm

15
.github/workflows/scripts/build.sh vendored Normal file

@ -0,0 +1,15 @@
#!/bin/bash
python_executable=python$1
cuda_home=/usr/local/cuda-$2
# Update paths
PATH=${cuda_home}/bin:$PATH
LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
# Install requirements
$python_executable -m pip install wheel packaging
$python_executable -m pip install -r requirements.txt
# Build
$python_executable setup.py bdist_wheel --dist-dir=dist

.github/workflows/scripts/create_release.js vendored Normal file

@ -0,0 +1,20 @@
// Uses Github's API to create the release and wait for result.
// We use a JS script since github CLI doesn't provide a way to wait for the release's creation and returns immediately.
module.exports = async (github, context, core) => {
try {
const response = await github.rest.repos.createRelease({
draft: false,
generate_release_notes: true,
name: process.env.RELEASE_TAG,
owner: context.repo.owner,
prerelease: false,
repo: context.repo.repo,
tag_name: process.env.RELEASE_TAG,
});
core.setOutput('upload_url', response.data.upload_url);
} catch (error) {
core.setFailed(error.message);
}
}

.github/workflows/scripts/cuda-install.sh vendored Normal file

@ -0,0 +1,18 @@
#!/bin/bash
# Replace '.' with '-' ex: 11.8 -> 11-8
cuda_version=$(echo $1 | tr "." "-")
# Removes '-' and '.' ex: ubuntu-20.04 -> ubuntu2004
OS=$(echo $2 | tr -d ".\-")
# Installs CUDA
wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
rm cuda-keyring_1.1-1_all.deb
sudo apt -qq update
sudo apt -y install cuda-${cuda_version} cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version}
sudo apt clean
# Test nvcc
PATH=/usr/local/cuda-$1/bin:${PATH}
nvcc --version

56
.github/workflows/scripts/env.sh vendored Normal file

@ -0,0 +1,56 @@
#!/bin/bash
# This file installs common linux environment tools
export LANG C.UTF-8
# python_version=$1
sudo apt-get update && \
sudo apt-get install -y --no-install-recommends \
software-properties-common \
sudo apt-get install -y --no-install-recommends \
build-essential \
apt-utils \
ca-certificates \
wget \
git \
vim \
libssl-dev \
curl \
unzip \
unrar \
cmake \
net-tools \
sudo \
autotools-dev \
rsync \
jq \
openssh-server \
tmux \
screen \
htop \
pdsh \
openssh-client \
lshw \
dmidecode \
util-linux \
automake \
autoconf \
libtool \
net-tools \
pciutils \
libpci-dev \
libaio-dev \
libcap2 \
libtinfo5 \
fakeroot \
devscripts \
debhelper \
nfs-common
# Remove github bloat files to free up disk space
sudo rm -rf "/usr/local/share/boost"
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
sudo rm -rf "/usr/share/dotnet"

.github/workflows/scripts/pytorch-install.sh vendored Normal file

@ -0,0 +1,14 @@
#!/bin/bash
python_executable=python$1
cuda_version=$2
# Install torch
$python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya
$python_executable -m pip install torch -f https://download.pytorch.org/whl/cu${cuda_version//./}/torch_stable.html
# Print version information
$python_executable --version
$python_executable -c "import torch; print('PyTorch:', torch.__version__)"
$python_executable -c "import torch; print('CUDA:', torch.version.cuda)"
$python_executable -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"

31
.github/workflows/yapf.yml vendored Normal file

@ -0,0 +1,31 @@
name: yapf

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  yapf:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install yapf==0.32.0
          pip install toml==0.10.2
      - name: Running yapf
        run: |
          yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**'

3
.gitignore vendored

@@ -170,3 +170,6 @@ cython_debug/
 # Python pickle files
 *.pkl
+
+# Sphinx documentation
+_build/

434
.pylintrc Normal file

@ -0,0 +1,434 @@
# This Pylint rcfile contains a best-effort configuration to uphold the
# best-practices and style described in the Google Python style guide:
# https://google.github.io/styleguide/pyguide.html
#
# Its canonical open-source location is:
# https://google.github.io/styleguide/pylintrc
[MASTER]
# Files or directories to be skipped. They should be base names, not paths.
ignore=docs,parallel_utils
# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
ignore-patterns=
# Pickle collected data for later comparisons.
persistent=no
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Use multiple processes to speed up Pylint.
jobs=4
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=abstract-method,
apply-builtin,
arguments-differ,
attribute-defined-outside-init,
backtick,
bad-option-value,
basestring-builtin,
buffer-builtin,
c-extension-no-member,
consider-using-enumerate,
cmp-builtin,
cmp-method,
coerce-builtin,
coerce-method,
delslice-method,
div-method,
duplicate-code,
eq-without-hash,
execfile-builtin,
file-builtin,
filter-builtin-not-iterating,
fixme,
getslice-method,
global-statement,
hex-method,
idiv-method,
implicit-str-concat-in-sequence,
import-error,
import-self,
import-star-module-level,
inconsistent-return-statements,
input-builtin,
intern-builtin,
invalid-str-codec,
locally-disabled,
logging-fstring-interpolation, # added by vLLM
logging-not-lazy, # added by vLLM
long-builtin,
long-suffix,
map-builtin-not-iterating,
misplaced-comparison-constant,
missing-class-docstring, # TODO (vLLM): enable
missing-function-docstring,
missing-module-docstring, # TODO (vLLM): enable
metaclass-assignment,
next-method-called,
next-method-defined,
no-absolute-import,
no-else-break,
no-else-continue,
no-else-raise,
no-else-return,
no-init, # added
no-member,
no-name-in-module,
no-self-use,
nonzero-method,
oct-method,
old-division,
old-ne-operator,
old-octal-literal,
old-raise-syntax,
parameter-unpacking,
print-statement,
raising-string,
range-builtin-not-iterating,
raw_input-builtin,
rdiv-method,
reduce-builtin,
relative-import,
reload-builtin,
round-builtin,
setslice-method,
signature-differs,
standarderror-builtin,
suppressed-message,
sys-max-int,
too-few-public-methods,
too-many-ancestors,
too-many-arguments,
too-many-boolean-expressions,
too-many-branches,
too-many-instance-attributes,
too-many-locals,
too-many-nested-blocks,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
trailing-newlines,
unichr-builtin,
unicode-builtin,
unnecessary-pass,
unpacking-in-except,
unspecified-encoding,
useless-else-on-loop,
useless-object-inheritance,
useless-suppression,
using-cmp-argument,
wrong-import-order,
xrange-builtin,
zip-builtin-not-iterating,
[REPORTS]
# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=text
# Tells whether to display a full report or only the messages
reports=no
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=
[BASIC]
# Good variable names which should always be accepted, separated by a comma
good-names=main,_
# Bad variable names which should always be refused, separated by a comma
bad-names=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Include a hint for the correct naming format with invalid-name
include-naming-hint=no
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
# Regular expression matching correct function names
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
# Regular expression matching correct variable names
variable-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct constant names
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct attribute names
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
# Regular expression matching correct argument names
argument-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class attribute names
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct inline iteration names
inlinevar-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class names
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
# Regular expression matching correct module names
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
# Regular expression matching correct method names
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=10
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=80
# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
# lines made too long by directives to pytype.
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=(?x)(
^\s*(\#\ )?<?https?://\S+>?$|
^\s*(from\s+\S+\s+)?import\s+.+$)
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes
# Maximum number of lines in a module
max-module-lines=99999
# String used as indentation unit. The internal Google style guide mandates 2
# spaces. Google's externaly-published style guide says 4, consistent with
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
# projects (like TensorFlow).
indent-string=' '
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=TODO
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=yes
[VARIABLES]
# Tells whether we should check for unused import in __init__ files.
init-import=no
# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid to define new builtins when possible.
additional-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
[LOGGING]
# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging,absl.logging,tensorflow.io.logging
[SIMILARITIES]
# Minimum lines number of a similarity.
min-similarity-lines=4
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
[SPELLING]
# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[IMPORTS]
# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,
TERMIOS,
Bastion,
rexec,
sets
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant, absl
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls,
class_
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=StandardError,
Exception,
BaseException

CONTRIBUTING.md

@@ -49,12 +49,15 @@ If not, please file a new issue, providing as much relevant information as possible.
 In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).

+We include a formatting script [`format.sh`](./format.sh) to format the code.
+
 ### Pull Requests

 When submitting a pull request:

 1. Make sure your code has been rebased on top of the latest commit on the main branch.
-2. Include a detailed description of the changes in the pull request.
+2. Ensure code is properly formatted by running [`format.sh`](./format.sh).
+3. Include a detailed description of the changes in the pull request.
 Explain why you made the changes you did.
 If your pull request fixes an open issue, please include a reference to it in the description.

README.md

@@ -17,8 +17,9 @@ Easy, fast, and cheap LLM serving for everyone
 ---

 *Latest News* 🔥
-- [2023/06] We officially released vLLM! vLLM has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid April. Check out our [blog post](https://vllm.ai).
+- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
+- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
+- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).

 ---
@@ -28,7 +29,7 @@ vLLM is fast with:
 - State-of-the-art serving throughput
 - Efficient management of attention key and value memory with **PagedAttention**
-- Dynamic batching of incoming requests
+- Continuous batching of incoming requests
 - Optimized CUDA kernels

 vLLM is flexible and easy to use with:
@@ -41,10 +42,19 @@ vLLM is flexible and easy to use with:
 vLLM seamlessly supports many Huggingface models, including the following architectures:

+- Aquila (`BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
+- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
+- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
+- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
 - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
-- GPTNeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
-- LLaMA (`lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
+- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
+- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
+- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
+- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
+- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
+- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
 - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
+- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)

 Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
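
For orientation, here is a minimal offline-inference sketch against one of the models listed above, using the same `LLM`/`SamplingParams` API that appears in the benchmark scripts later in this diff; the model name and sampling values are examples only, not part of the change:

import torch  # only to check a GPU is available before loading the model
from vllm import LLM, SamplingParams

assert torch.cuda.is_available()
# Example model; trust_remote_code mirrors the new --trust-remote-code flags and is
# needed for models whose Hugging Face repos ship custom code (e.g. Qwen, Baichuan, InternLM).
llm = LLM(model="Qwen/Qwen-7B", trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
outputs = llm.generate(["Hello, my name is"], sampling_params)
print(outputs[0].outputs[0].text)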

benchmarks/benchmark_latency.py

@@ -17,9 +17,11 @@ def main(args: argparse.Namespace):
     # the engine will automatically process the request in multiple batches.
     llm = LLM(
         model=args.model,
+        tokenizer=args.tokenizer,
         tensor_parallel_size=args.tensor_parallel_size,
         max_num_seqs=args.batch_size,
         max_num_batched_tokens=args.batch_size * args.input_len,
+        trust_remote_code=args.trust_remote_code,
     )

     sampling_params = SamplingParams(
@@ -63,6 +65,7 @@ if __name__ == '__main__':
         description='Benchmark the latency of processing a single batch of '
                     'requests till completion.')
     parser.add_argument('--model', type=str, default='facebook/opt-125m')
+    parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
@@ -72,5 +75,7 @@ if __name__ == '__main__':
     parser.add_argument('--use-beam-search', action='store_true')
     parser.add_argument('--num-iters', type=int, default=3,
                         help='Number of iterations to run.')
+    parser.add_argument('--trust-remote-code', action='store_true',
+                        help='trust remote code from huggingface')
     args = parser.parse_args()
     main(args)
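
The `SamplingParams(` call kept as context above is filled in from the CLI flags; a rough sketch of the resulting object under default flag values, illustrative only (the real script passes args.n, args.use_beam_search and args.output_len):

from vllm import SamplingParams

# Illustrative values: n, use_beam_search and max_tokens correspond to --n,
# --use-beam-search and --output-len; ignore_eos (assumed here) keeps generation
# running for the full requested output length so latency is comparable across runs.
sampling_params = SamplingParams(
    n=1,
    temperature=1.0,
    top_p=1.0,
    use_beam_search=False,
    ignore_eos=True,
    max_tokens=128,
)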

benchmarks/benchmark_serving.py

@@ -24,20 +24,13 @@ from typing import AsyncGenerator, List, Tuple
 import aiohttp
 import numpy as np
-from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase
+
+from vllm.transformers_utils.tokenizer import get_tokenizer

 # (prompt len, output len, latency)
 REQUEST_LATENCY: List[Tuple[int, int, float]] = []

-def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
-    config = AutoConfig.from_pretrained(model_name)
-    if config.model_type == "llama":
-        # A workaround for potential protobuf errors.
-        model_name = "hf-internal-testing/llama-tokenizer"
-    return AutoTokenizer.from_pretrained(model_name)
-
 def sample_requests(
     dataset_path: str,
     num_requests: int,
@@ -184,7 +177,7 @@ def main(args: argparse.Namespace):
     np.random.seed(args.seed)

     api_url = f"http://{args.host}:{args.port}/generate"
-    tokenizer = get_tokenizer(args.tokenizer)
+    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
     input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

     benchmark_start_time = time.time()
@@ -217,7 +210,7 @@ if __name__ == "__main__":
     parser.add_argument("--backend", type=str, default="vllm",
                         choices=["vllm", "tgi"])
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=8001)
+    parser.add_argument("--port", type=int, default=8000)
     parser.add_argument("--dataset", type=str, required=True,
                         help="Path to the dataset.")
     parser.add_argument("--tokenizer", type=str, required=True,
@@ -234,5 +227,7 @@
                         "Otherwise, we use Poisson process to synthesize "
                         "the request arrival times.")
     parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument('--trust-remote-code', action='store_true',
+                        help='trust remote code from huggingface')
     args = parser.parse_args()
     main(args)
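
The benchmark now reuses vLLM's shared tokenizer helper instead of a local copy; a minimal sketch of the new call (the model name is only an example):

from vllm.transformers_utils.tokenizer import get_tokenizer

# Centralizes the llama-tokenizer workaround that the deleted local helper carried,
# and forwards trust_remote_code to Hugging Face when loading the tokenizer.
tokenizer = get_tokenizer("openlm-research/open_llama_13b", trust_remote_code=False)
print(type(tokenizer).__name__)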

benchmarks/benchmark_throughput.py

@@ -6,23 +6,11 @@ import time
 from typing import List, Tuple

 import torch
-from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM,
-                          PreTrainedTokenizerBase)
+from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
 from tqdm import tqdm

 from vllm import LLM, SamplingParams
+from vllm.transformers_utils.tokenizer import get_tokenizer

-def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
-    config = AutoConfig.from_pretrained(model_name)
-    if config.model_type == "llama":
-        # A workaround for potential protobuf errors.
-        model_name = "hf-internal-testing/llama-tokenizer"
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        # To enable padding in the HF backend.
-        tokenizer.pad_token = tokenizer.eos_token
-        return tokenizer
-    return AutoTokenizer.from_pretrained(model_name)

 def sample_requests(
@@ -74,15 +62,19 @@ def sample_requests(
 def run_vllm(
     requests: List[Tuple[str, int, int]],
     model: str,
+    tokenizer: str,
     tensor_parallel_size: int,
     seed: int,
     n: int,
     use_beam_search: bool,
+    trust_remote_code: bool,
 ) -> float:
     llm = LLM(
         model=model,
+        tokenizer=tokenizer,
         tensor_parallel_size=tensor_parallel_size,
         seed=seed,
+        trust_remote_code=trust_remote_code,
     )

     # Add the requests to the engine.
@@ -116,11 +108,14 @@ def run_hf(
     n: int,
     use_beam_search: bool,
     max_batch_size: int,
+    trust_remote_code: bool,
 ) -> float:
     assert not use_beam_search
-    tokenizer = get_tokenizer(model)
-    llm = AutoModelForCausalLM.from_pretrained(
-        model, torch_dtype=torch.float16)
+    llm = AutoModelForCausalLM.from_pretrained(model,
+        torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
+    if llm.config.model_type == "llama":
+        # To enable padding in the HF backend.
+        tokenizer.pad_token = tokenizer.eos_token
     llm = llm.cuda()

     pbar = tqdm(total=len(requests))
@@ -170,17 +165,18 @@ def main(args: argparse.Namespace):
     random.seed(args.seed)

     # Sample the requests.
-    tokenizer = get_tokenizer(args.model)
+    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
     requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

     if args.backend == "vllm":
         elapsed_time = run_vllm(
-            requests, args.model, args.tensor_parallel_size, args.seed, args.n,
-            args.use_beam_search)
+            requests, args.model, args.tokenizer, args.tensor_parallel_size,
+            args.seed, args.n, args.use_beam_search, args.trust_remote_code)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
-        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
-                              args.use_beam_search, args.hf_max_batch_size)
+        elapsed_time = run_hf(
+            requests, args.model, tokenizer, args.n, args.use_beam_search,
+            args.hf_max_batch_size, args.trust_remote_code)
     else:
         raise ValueError(f"Unknown backend: {args.backend}")
     total_num_tokens = sum(
@@ -198,6 +194,7 @@ if __name__ == "__main__":
     parser.add_argument("--dataset", type=str, required=True,
                         help="Path to the dataset.")
     parser.add_argument("--model", type=str, default="facebook/opt-125m")
+    parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n", type=int, default=1,
                         help="Number of generated sequences per prompt.")
@@ -207,12 +204,18 @@
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--hf-max-batch-size", type=int, default=None,
                         help="Maximum batch size for HF backend.")
+    parser.add_argument('--trust-remote-code',
+                        action='store_true',
+                        help='trust remote code from huggingface')
     args = parser.parse_args()
     if args.backend == "vllm":
         if args.hf_max_batch_size is not None:
             raise ValueError("HF max batch size is only for HF backend.")
     elif args.backend == "hf":
         if args.hf_max_batch_size is None:
             raise ValueError("HF max batch size is required for HF backend.")
+    if args.tokenizer is None:
+        args.tokenizer = args.model
     main(args)

benchmarks/launch_tgi_server.sh

@@ -1,6 +1,6 @@
 #!/bin/bash

-PORT=8001
+PORT=8000
 MODEL=$1
 TOKENS=$2

csrc/activation.cpp

@@ -4,9 +4,25 @@ void silu_and_mul(
   torch::Tensor& out,
   torch::Tensor& input);

+void gelu_new(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+void gelu_fast(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
     "silu_and_mul",
     &silu_and_mul,
     "Activation function used in SwiGLU.");
+  m.def(
+    "gelu_new",
+    &gelu_new,
+    "GELU implementation used in GPT-2.");
+  m.def(
+    "gelu_fast",
+    &gelu_fast,
+    "Approximate GELU implementation.");
 }
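
A hedged sketch of calling the newly bound ops from Python, assuming the compiled extension is exposed as `vllm.activation_ops` (the module name is an assumption; it is not shown in this diff). Both ops write into a caller-allocated output tensor, matching the (out, input) signatures above:

import torch
from vllm import activation_ops  # assumed name of the compiled extension module

x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)
activation_ops.gelu_new(out, x)   # GPT-2 style tanh GELU, written into `out`
activation_ops.gelu_fast(out, x)  # approximate GELU variant, written into `out`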

csrc/activation_kernels.cu

@@ -46,3 +46,71 @@ void silu_and_mul(
      d);
    });
}
namespace vllm {
// Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
scalar_t* __restrict__ out, // [num_tokens, d]
const scalar_t* __restrict__ input, // [num_tokens, d]
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
}
}
} // namespace vllm
// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int num_tokens = input.size(0); \
int d = input.size(1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
AT_DISPATCH_FLOATING_TYPES_AND2( \
at::ScalarType::Half, \
at::ScalarType::BFloat16, \
input.scalar_type(), \
"activation_kernel", \
[&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
});
namespace vllm {
template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
const float x3 = (float) (x * x * x);
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
return ((T) 0.5) * x * (((T) 1.0) + t);
}
template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
const float f = (float) x;
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
return ((T) 0.5) * x * (((T) 1.0) + t);
}
} // namespace vllm
void gelu_new(
torch::Tensor& out, // [num_tokens, d]
torch::Tensor& input) // [num_tokens, d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}
void gelu_fast(
torch::Tensor& out, // [num_tokens, d]
torch::Tensor& input) // [num_tokens, d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
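
For reference, the two device functions above implement the tanh-based GELU approximation; a plain PyTorch sketch mirroring the same formulas, useful as a numerical cross-check (not part of the diff):

import torch

def gelu_new_ref(x: torch.Tensor) -> torch.Tensor:
    # Mirrors gelu_new_kernel: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + torch.tanh(0.79788456 * (x + 0.044715 * x.pow(3))))

def gelu_fast_ref(x: torch.Tensor) -> torch.Tensor:
    # Mirrors gelu_fast_kernel: the same approximation, factored as x * (1 + 0.044715 * x^2)
    return 0.5 * x * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

x = torch.randn(8, 1024)
print(torch.allclose(gelu_new_ref(x), gelu_fast_ref(x), atol=1e-5))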

csrc/attention.cpp

@@ -1,15 +1,18 @@
 #include <torch/extension.h>
+#include <c10/util/Optional.h>

 void single_query_cached_kv_attention(
   torch::Tensor& out,
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
   int block_size,
-  int max_context_len);
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
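
The new `head_mapping` argument tells the kernel which KV-cache head each query head should read, which is what lets the same kernel serve MHA, MQA and GQA models; a hedged sketch of how a caller might construct it (the head counts are hypothetical):

import torch

# Hypothetical setup: 32 query heads sharing 8 KV heads (grouped-query attention).
num_heads, num_kv_heads = 32, 8
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32),
    num_heads // num_kv_heads,
)
# The kernel reads `kv_head_idx = head_mapping[head_idx]`; for plain multi-head
# attention this is simply arange(num_heads), i.e. an identity mapping.
print(head_mapping.tolist())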

csrc/attention/attention_kernels.cu

@ -74,14 +74,20 @@ template<
__global__ void single_query_cached_kv_attention_kernel( __global__ void single_query_cached_kv_attention_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int* __restrict__ head_mapping, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const int q_stride) { const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x; const int thread_idx = threadIdx.x;
@ -90,7 +96,9 @@ __global__ void single_query_cached_kv_attention_kernel(
const int head_idx = blockIdx.x; const int head_idx = blockIdx.x;
const int num_heads = gridDim.x; const int num_heads = gridDim.x;
const int kv_head_idx = head_mapping[head_idx];
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query. // A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group // The vector size is configured in such a way that the threads in a thread group
@ -114,12 +122,13 @@ __global__ void single_query_cached_kv_attention_kernel(
// th vectors of the query, and so on. // th vectors of the query, and so on.
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
Q_vec q_vecs[NUM_VECS_PER_THREAD]; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
} }
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
// Memory planning. // Memory planning.
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
@ -156,8 +165,8 @@ __global__ void single_query_cached_kv_attention_kernel(
#pragma unroll #pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+ head_idx * HEAD_SIZE * BLOCK_SIZE + kv_head_idx * kv_head_stride
+ physical_block_offset * x; + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset1 = (vec_idx * VEC_SIZE) / x;
@ -167,12 +176,14 @@ __global__ void single_query_cached_kv_attention_kernel(
// Compute dot product. // Compute dot product.
// This includes a reduction across the threads in the same thread group. // This includes a reduction across the threads in the same thread group.
const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs); float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
const bool mask = token_idx >= context_len; // Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// Store the partial reductions to shared memory. // Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits. // NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= context_len;
logits[token_idx] = mask ? 0.f : qk; logits[token_idx] = mask ? 0.f : qk;
// Update the max value. // Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk); qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@ -242,8 +253,8 @@ __global__ void single_query_cached_kv_attention_kernel(
L_vec logits_vec; L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx)); from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+ head_idx * HEAD_SIZE * BLOCK_SIZE; + kv_head_idx * kv_head_stride;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@ -324,11 +335,15 @@ __global__ void single_query_cached_kv_attention_kernel(
query_ptr, \ query_ptr, \
key_cache_ptr, \ key_cache_ptr, \
value_cache_ptr, \ value_cache_ptr, \
head_mapping_ptr, \
scale, \ scale, \
block_tables_ptr, \ block_tables_ptr, \
context_lens_ptr, \ context_lens_ptr, \
max_num_blocks_per_seq, \ max_num_blocks_per_seq, \
query_stride); alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride);
// TODO(woosuk): Tune NUM_THREADS.
template<
@@ -340,23 +355,33 @@ void single_query_cached_kv_attention_launcher(
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
+ torch::Tensor& head_mapping,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
- int max_context_len) {
+ int max_context_len,
+ const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
- int query_stride = query.stride(0);
+ int q_stride = query.stride(0);
+ int kv_block_stride = key_cache.stride(0);
+ int kv_head_stride = key_cache.stride(1);
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
+ // NOTE: alibi_slopes is optional.
+ const float* alibi_slopes_ptr = alibi_slopes ?
+   reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+   : nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+ int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
@@ -371,7 +396,7 @@ void single_query_cached_kv_attention_launcher(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we omitted head sizes
- // 32, 160, 192, 256.
+ // 32, 160, 192.
// case 32:
// LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
// break;
@@ -384,6 +409,9 @@ void single_query_cached_kv_attention_launcher(
case 96:
LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
break;
+ case 112:
+ LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS);
+ break;
case 128:
LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
break;
@@ -393,9 +421,9 @@ void single_query_cached_kv_attention_launcher(
// case 192:
// LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
// break;
- // case 256:
- // LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
- // break;
+ case 256:
+ LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
+ break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
@@ -408,10 +436,12 @@ void single_query_cached_kv_attention_launcher(
query, \
key_cache, \
value_cache, \
+ head_mapping, \
scale, \
block_tables, \
context_lens, \
- max_context_len);
+ max_context_len, \
+ alibi_slopes);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@@ -454,11 +484,13 @@ void single_query_cached_kv_attention(
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
+ torch::Tensor& head_mapping, // [num_heads]
float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
int block_size,
- int max_context_len) {
+ int max_context_len,
+ const c10::optional<torch::Tensor>& alibi_slopes) {
if (query.dtype() == at::ScalarType::Float) {
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
} else if (query.dtype() == at::ScalarType::Half) {

View File

@@ -7,11 +7,13 @@ template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
const int64_t* __restrict__ positions, // [num_tokens]
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
- scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
+ scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
- const int stride,
+ const int query_stride,
+ const int key_stride,
const int num_heads,
+ const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
@@ -19,17 +21,17 @@ __global__ void rotary_embedding_neox_kernel(
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const int embed_dim = rot_dim / 2;
- const int n = num_heads * embed_dim;
- for (int i = threadIdx.x; i < n; i += blockDim.x) {
+ const int nq = num_heads * embed_dim;
+ for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
- const int token_head = token_idx * stride + head_idx * head_size;
+ const int token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
- const int out_x = token_idx * stride + head_idx * head_size + x_index;
- const int out_y = token_idx * stride + head_idx * head_size + y_index;
+ const int out_x = token_idx * query_stride + head_idx * head_size + x_index;
+ const int out_y = token_idx * query_stride + head_idx * head_size + y_index;
const scalar_t cos = __ldg(cache_ptr + x_index);
const scalar_t sin = __ldg(cache_ptr + y_index);
@@ -38,6 +40,22 @@ __global__ void rotary_embedding_neox_kernel(
const scalar_t q_y = query[token_head + y_index];
query[out_x] = q_x * cos - q_y * sin;
query[out_y] = q_y * cos + q_x * sin;
+ }
+ const int nk = num_kv_heads * embed_dim;
+ for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+ const int head_idx = i / embed_dim;
+ const int token_head = token_idx * key_stride + head_idx * head_size;
+ const int rot_offset = i % embed_dim;
+ const int x_index = rot_offset;
+ const int y_index = embed_dim + rot_offset;
+ const int out_x = token_idx * key_stride + head_idx * head_size + x_index;
+ const int out_y = token_idx * key_stride + head_idx * head_size + y_index;
+ const scalar_t cos = __ldg(cache_ptr + x_index);
+ const scalar_t sin = __ldg(cache_ptr + y_index);
const scalar_t k_x = key[token_head + x_index];
const scalar_t k_y = key[token_head + y_index];
@@ -51,15 +69,16 @@ __global__ void rotary_embedding_neox_kernel(
void rotary_embedding_neox(
torch::Tensor& positions, // [num_tokens]
torch::Tensor& query, // [num_tokens, num_heads * head_size]
- torch::Tensor& key, // [num_tokens, num_heads * head_size]
+ torch::Tensor& key, // [num_tokens, num_kv_heads * head_size]
int head_size,
torch::Tensor& cos_sin_cache) // [max_position, rot_dim]
{
int num_tokens = query.size(0);
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(1) / head_size;
- int stride = query.stride(0);
- TORCH_CHECK(stride == key.stride(0));
+ int num_kv_heads = key.size(1) / head_size;
+ int query_stride = query.stride(0);
+ int key_stride = key.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
@@ -76,8 +95,10 @@ void rotary_embedding_neox(
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
- stride,
+ query_stride,
+ key_stride,
num_heads,
+ num_kv_heads,
head_size);
});
}

View File

@@ -4,14 +4,14 @@
```bash
# Install dependencies.
- pip -r requirements-docs.txt
+ pip install -r requirements-docs.txt
# Build the docs.
make clean
make html
```
- ## Open the docs with your brower
+ ## Open the docs with your browser
```bash
python -m http.server -d build/html/

View File

@@ -29,7 +29,7 @@ vLLM is fast with:
* State-of-the-art serving throughput
* Efficient management of attention key and value memory with **PagedAttention**
- * Dynamic batching of incoming requests
+ * Continuous batching of incoming requests
* Optimized CUDA kernels
vLLM is flexible and easy to use with:
@@ -40,7 +40,11 @@ vLLM is flexible and easy to use with:
* Streaming outputs
* OpenAI-compatible API server
- For more information, please refer to our `blog post <https://vllm.ai>`_.
+ For more information, check out the following:
+ * `vLLM announcing blog post <https://vllm.ai>`_ (intro to PagedAttention)
+ * `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency <https://www.anyscale.com/blog/continuous-batching-llm-inference>`_ by Cade Daniel et al.
Documentation
@@ -53,6 +57,13 @@ Documentation
getting_started/installation
getting_started/quickstart
+ .. toctree::
+ :maxdepth: 1
+ :caption: Serving
+ serving/distributed_serving
+ serving/run_on_sky
.. toctree::
:maxdepth: 1
:caption: Models

View File

@@ -14,18 +14,45 @@ Alongside each architecture, we include some popular models that use it.
* - Architecture
- Models
- Example HuggingFace Models
+ * - :code:`AquilaForCausalLM`
+ - Aquila
+ - :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc.
+ * - :code:`BaiChuanForCausalLM`
+ - Baichuan
+ - :code:`baichuan-inc/Baichuan-7B`, :code:`baichuan-inc/Baichuan-13B-Chat`, etc.
+ * - :code:`BloomForCausalLM`
+ - BLOOM, BLOOMZ, BLOOMChat
+ - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
+ * - :code:`FalconForCausalLM`
+ - Falcon
+ - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
* - :code:`GPT2LMHeadModel`
- GPT-2
- :code:`gpt2`, :code:`gpt2-xl`, etc.
+ * - :code:`GPTBigCodeForCausalLM`
+ - StarCoder, SantaCoder, WizardCoder
+ - :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc.
+ * - :code:`GPTJForCausalLM`
+ - GPT-J
+ - :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc.
* - :code:`GPTNeoXForCausalLM`
- GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
- :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
+ * - :code:`InternLMForCausalLM`
+ - InternLM
+ - :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc.
* - :code:`LlamaForCausalLM`
- - LLaMA, Vicuna, Alpaca, Koala, Guanaco
- - :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc.
+ - LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
+ - :code:`meta-llama/Llama-2-13b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc.
+ * - :code:`MPTForCausalLM`
+ - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
+ - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
* - :code:`OPTForCausalLM`
- OPT, OPT-IML
- :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
+ * - :code:`QWenLMHeadModel`
+ - Qwen
+ - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
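As a quick illustration, the snippet below is a minimal sketch of running one of the architectures listed above through vLLM's :code:`LLM` class; :code:`facebook/opt-125m` is used only as an example here, and any model from the table can be substituted.
.. code-block:: python

    from vllm import LLM, SamplingParams

    # Any model whose architecture appears in the table above can be used here.
    llm = LLM(model="facebook/opt-125m")
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # Generate a completion and print the text of the first candidate.
    outputs = llm.generate(["Hello, my name is"], sampling_params)
    for output in outputs:
        print(output.outputs[0].text)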

View File

@ -0,0 +1,38 @@
.. _distributed_serving:
Distributed Inference and Serving
=================================
vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We manage the distributed runtime with `Ray <https://github.com/ray-project/ray>`_. To run distributed inference, install Ray with:
.. code-block:: console
$ pip install ray
To run multi-GPU inference with the :code:`LLM` class, set the :code:`tensor_parallel_size` argument to the number of GPUs you want to use. For example, to run inference on 4 GPUs:
.. code-block:: python
from vllm import LLM
llm = LLM("facebook/opt-13b", tensor_parallel_size=4)
output = llm.generate("San Francisco is a")
To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument when starting the server. For example, to run the API server on 4 GPUs:
.. code-block:: console
$ python -m vllm.entrypoints.api_server \
$ --model facebook/opt-13b \
$ --tensor-parallel-size 4
To scale vLLM beyond a single machine, start a `Ray runtime <https://docs.ray.io/en/latest/ray-core/starting-ray.html>`_ via CLI before running vLLM:
.. code-block:: console
$ # On head node
$ ray start --head
$ # On worker nodes
$ ray start --address=<ray-head-address>
After that, you can run inference and serving across multiple machines by launching the vLLM process on the head node and setting :code:`tensor_parallel_size` to the total number of GPUs across all machines, as sketched below.
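For illustration, a minimal sketch of the multi-node case (the model name and the GPU count of 8 are placeholders; the command and flags are the same ones introduced above):
.. code-block:: console

    $ # On the head node, after Ray has been started on every machine as shown above:
    $ python -m vllm.entrypoints.api_server \
    $     --model facebook/opt-13b \
    $     --tensor-parallel-size 8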

View File

@ -0,0 +1,69 @@
.. _on_cloud:
Running on clouds with SkyPilot
===============================
.. raw:: html
<p align="center">
<img src="https://imgur.com/yxtzPEu.png" alt="vLLM"/>
</p>
vLLM can be run on the cloud to scale to multiple GPUs with `SkyPilot <https://github.com/skypilot-org/skypilot>`__, an open-source framework for running LLMs on any cloud.
To install SkyPilot and set up your cloud credentials, run:
.. code-block:: console
$ pip install skypilot
$ sky check
See the vLLM SkyPilot YAML for serving, `serving.yaml <https://github.com/skypilot-org/skypilot/blob/master/llm/vllm/serve.yaml>`__.
.. code-block:: yaml
resources:
accelerators: A100
envs:
MODEL_NAME: decapoda-research/llama-13b-hf
TOKENIZER: hf-internal-testing/llama-tokenizer
setup: |
conda create -n vllm python=3.9 -y
conda activate vllm
git clone https://github.com/vllm-project/vllm.git
cd vllm
pip install .
pip install gradio
run: |
conda activate vllm
echo 'Starting vllm api server...'
python -u -m vllm.entrypoints.api_server \
--model $MODEL_NAME \
--tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
--tokenizer $TOKENIZER 2>&1 | tee api_server.log &
echo 'Waiting for vllm api server to start...'
while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done
echo 'Starting gradio server...'
python vllm/examples/gradio_webserver.py
Start serving the LLaMA-13B model on an A100 GPU:
.. code-block:: console
$ sky launch serving.yaml
Check the output of the command. There will be a shareable Gradio link (like the last line of the following). Open it in your browser to use the LLaMA model for text completion.
.. code-block:: console
(task, pid=7431) Running on public URL: https://<gradio-hash>.gradio.live
**Optional**: Serve the 65B model instead of the default 13B and use more GPUs:
.. code-block:: console
sky launch -c vllm-serve-new -s serve.yaml --gpus A100:8 --env MODEL_NAME=decapoda-research/llama-65b-hf

View File

@ -14,7 +14,9 @@ def clear_line(n: int = 1) -> None:
print(LINE_UP, end=LINE_CLEAR, flush=True) print(LINE_UP, end=LINE_CLEAR, flush=True)
def post_http_request(prompt: str, api_url: str, n: int = 1, def post_http_request(prompt: str,
api_url: str,
n: int = 1,
stream: bool = False) -> requests.Response: stream: bool = False) -> requests.Response:
headers = {"User-Agent": "Test Client"} headers = {"User-Agent": "Test Client"}
pload = { pload = {
@ -30,7 +32,8 @@ def post_http_request(prompt: str, api_url: str, n: int = 1,
def get_streaming_response(response: requests.Response) -> Iterable[List[str]]: def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b"\0"): delimiter=b"\0"):
if chunk: if chunk:
data = json.loads(chunk.decode("utf-8")) data = json.loads(chunk.decode("utf-8"))

View File

@ -12,9 +12,14 @@ def http_bot(prompt):
"stream": True, "stream": True,
"max_tokens": 128, "max_tokens": 128,
} }
response = requests.post(args.model_url, headers=headers, json=pload, stream=True) response = requests.post(args.model_url,
headers=headers,
json=pload,
stream=True)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b"\0"):
if chunk: if chunk:
data = json.loads(chunk.decode("utf-8")) data = json.loads(chunk.decode("utf-8"))
output = data["text"][0] output = data["text"][0]
@ -23,11 +28,11 @@ def http_bot(prompt):
def build_demo(): def build_demo():
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.Markdown( gr.Markdown("# vLLM text completion demo\n")
"# vLLM text completion demo\n" inputbox = gr.Textbox(label="Input",
) placeholder="Enter text and press ENTER")
inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER") outputbox = gr.Textbox(label="Output",
outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model") placeholder="Generated result from the model")
inputbox.submit(http_bot, [inputbox], [outputbox]) inputbox.submit(http_bot, [inputbox], [outputbox])
return demo return demo
@ -36,7 +41,9 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8001) parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--model-url", type=str, default="http://localhost:8000/generate") parser.add_argument("--model-url",
type=str,
default="http://localhost:8000/generate")
args = parser.parse_args() args = parser.parse_args()
demo = build_demo() demo = build_demo()

View File

@ -10,19 +10,25 @@ def main(args: argparse.Namespace):
# Test the following prompts. # Test the following prompts.
test_prompts = [ test_prompts = [
("A robot may not injure a human being", SamplingParams()), ("A robot may not injure a human being",
SamplingParams(temperature=0.0)),
("To be or not to be,", ("To be or not to be,",
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)), SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
("What is the meaning of life?", ("What is the meaning of life?",
SamplingParams(n=2, best_of=5, temperature=0.8, top_p=0.95, frequency_penalty=0.1)), SamplingParams(n=2,
best_of=5,
temperature=0.8,
top_p=0.95,
frequency_penalty=0.1)),
("It is only with the heart that one can see rightly", ("It is only with the heart that one can see rightly",
SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)), SamplingParams(n=3, best_of=3, use_beam_search=True,
temperature=0.0)),
] ]
# Run the engine by calling `engine.step()` manually. # Run the engine by calling `engine.step()` manually.
request_id = 0 request_id = 0
while True: while True:
# To test iteration-level scheduling, we add one request at each step. # To test continuous batching, we add one request at each step.
if test_prompts: if test_prompts:
prompt, sampling_params = test_prompts.pop(0) prompt, sampling_params = test_prompts.pop(0)
engine.add_request(str(request_id), prompt, sampling_params) engine.add_request(str(request_id), prompt, sampling_params)

View File

@ -1,6 +1,5 @@
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
# Sample prompts. # Sample prompts.
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",

View File

@ -0,0 +1,33 @@
import openai
# Modify OpenAI's API key and API base to use vLLM's API server.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
# List models API
models = openai.Model.list()
print("Models:", models)
model = models["data"][0]["id"]
# Chat completion API
chat_completion = openai.ChatCompletion.create(
model=model,
messages=[{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}])
print("Chat completion results:")
print(chat_completion)

View File

@@ -3,21 +3,26 @@ import openai
# Modify OpenAI's API key and API base to use vLLM's API server.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
- model = "facebook/opt-125m"
- # Test list models API
+ # List models API
models = openai.Model.list()
print("Models:", models)
- # Test completion API
- stream = True
- completion = openai.Completion.create(
-     model=model, prompt="A robot may not injure a human being", echo=False, n=2,
-     best_of=3, stream=stream, logprobs=3)
- # print the completion
+ model = models["data"][0]["id"]
+ # Completion API
+ stream = False
+ completion = openai.Completion.create(
+     model=model,
+     prompt="A robot may not injure a human being",
+     echo=False,
+     n=2,
+     stream=stream,
+     logprobs=3)
+ print("Completion results:")
if stream:
    for c in completion:
        print(c)
else:
- print("Completion result:", completion)
+ print(completion)

format.sh Executable file
View File

@ -0,0 +1,108 @@
#!/usr/bin/env bash
# YAPF formatter, adapted from ray and skypilot.
#
# Usage:
# # Do work and commit your work.
# # Format files that differ from origin/main.
# bash format.sh
# # Commit changed files with message 'Run yapf and pylint'
#
#
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.
# Cause the script to exit if a single command fails
set -eo pipefail
# this stops git rev-parse from failing if we run this from the .git directory
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
ROOT="$(git rev-parse --show-toplevel)"
builtin cd "$ROOT" || exit 1
YAPF_VERSION=$(yapf --version | awk '{print $2}')
PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}')
MYPY_VERSION=$(mypy --version | awk '{print $2}')
# # params: tool name, tool version, required version
tool_version_check() {
if [[ $2 != $3 ]]; then
echo "Wrong $1 version installed: $3 is required, not $2."
exit 1
fi
}
tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "pylint" $PYLINT_VERSION "$(grep "pylint==" requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
YAPF_FLAGS=(
'--recursive'
'--parallel'
)
YAPF_EXCLUDES=(
'--exclude' 'build/**'
'--exclude' 'vllm/model_executor/parallel_utils/**'
)
# Format specified files
format() {
yapf --in-place "${YAPF_FLAGS[@]}" "$@"
}
# Format files that differ from main branch. Ignores dirs that are not slated
# for autoformat yet.
format_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
# could cause yapf to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
fi
}
# Format all files
format_all() {
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm
}
## This flag formats individual files. --files *must* be the first command line
## arg to use this option.
if [[ "$1" == '--files' ]]; then
format "${@:2}"
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is formatted.
elif [[ "$1" == '--all' ]]; then
format_all
else
# Format only the files that changed in last commit.
format_changed
fi
echo 'vLLM yapf: Done'
# Run mypy
# TODO(zhuohan): Enable mypy
# echo 'vLLM mypy:'
# mypy
# Run Pylint
echo 'vLLM Pylint:'
pylint vllm
if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:'
echo
git --no-pager diff --name-only
exit 1
fi

View File

@@ -1,2 +1,12 @@
- mypy
+ # formatting
+ yapf==0.32.0
+ pylint==2.8.2
+ # type checking
+ mypy==0.991
+ types-PyYAML
+ types-requests
+ types-setuptools
+ # testing
pytest

View File

@@ -1,11 +1,11 @@
ninja # For faster builds.
psutil
- ray
+ ray >= 2.5.1
sentencepiece # Required for LLaMA tokenizer.
numpy
torch >= 2.0.0
- transformers >= 4.28.0 # Required for LLaMA.
+ transformers >= 4.31.0 # Required for LLaMA-2.
- xformers >= 0.0.19
+ xformers >= 0.0.21
fastapi
uvicorn
- pydantic # Required for OpenAI server.
+ pydantic < 2 # Required for OpenAI server.

View File

@@ -20,10 +20,9 @@ ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
- if not torch.cuda.is_available():
+ if CUDA_HOME is None:
    raise RuntimeError(
-       f"Cannot find CUDA at CUDA_HOME: {CUDA_HOME}. "
-       "CUDA must be available in order to build the package.")
+       f"Cannot find CUDA_HOME. CUDA must be available to build the package.")
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
@@ -48,12 +47,6 @@ for i in range(device_count):
    raise RuntimeError(
        "GPUs with compute capability less than 7.0 are not supported.")
    compute_capabilities.add(major * 10 + minor)
- # If no GPU is available, add all supported compute capabilities.
- if not compute_capabilities:
-     compute_capabilities = {70, 75, 80, 86, 90}
- # Add target compute capabilities to NVCC flags.
- for capability in compute_capabilities:
-     NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
# Validate the NVCC CUDA version.
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
@@ -62,10 +55,31 @@ if nvcc_cuda_version < Version("11.0"):
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
    raise RuntimeError(
        "CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
+ if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
+     # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
+     # However, GPUs with compute capability 8.9 can also run the code generated by
+     # the previous versions of CUDA 11 and targeting compute capability 8.0.
+     # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
+     # instead of 8.9.
+     compute_capabilities.remove(89)
+     compute_capabilities.add(80)
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
    raise RuntimeError(
        "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
+ # If no GPU is available, add all supported compute capabilities.
+ if not compute_capabilities:
+     compute_capabilities = {70, 75, 80}
+     if nvcc_cuda_version >= Version("11.1"):
+         compute_capabilities.add(86)
+     if nvcc_cuda_version >= Version("11.8"):
+         compute_capabilities.add(89)
+         compute_capabilities.add(90)
+ # Add target compute capabilities to NVCC flags.
+ for capability in compute_capabilities:
+     NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
# Use NVCC threads to parallelize the build.
if nvcc_cuda_version >= Version("11.2"):
    num_threads = min(os.cpu_count(), 8)

View File

@ -1,6 +1,6 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from transformers.activations import get_activation
from vllm import activation_ops from vllm import activation_ops
@ -28,3 +28,45 @@ def test_silu_and_mul() -> None:
for d in [512, 4096, 5120, 13824]: for d in [512, 4096, 5120, 13824]:
print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}') print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
run_silu_and_mul(num_tokens, d, dtype) run_silu_and_mul(num_tokens, d, dtype)
@torch.inference_mode()
def run_gelu_new(
num_tokens: int,
d: int,
dtype: torch.dtype,
) -> None:
x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
activation_ops.gelu_new(out, x)
ref_out = get_activation("gelu_new")(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
def test_gelu_new() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for num_tokens in [7, 83, 2048]:
for d in [512, 4096, 5120, 13824]:
print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
run_gelu_new(num_tokens, d, dtype)
@torch.inference_mode()
def run_gelu_fast(
num_tokens: int,
d: int,
dtype: torch.dtype,
) -> None:
x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
activation_ops.gelu_fast(out, x)
ref_out = get_activation("gelu_fast")(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
def test_gelu_fast() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for num_tokens in [7, 83, 2048]:
for d in [512, 4096, 5120, 13824]:
print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
run_gelu_fast(num_tokens, d, dtype)

View File

@ -84,8 +84,8 @@ def ref_multi_query_kv_attention(
seq_len = end_idx - start_idx seq_len = end_idx - start_idx
# Create attention mask. # Create attention mask.
attn_mask = torch.triu( attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1) diagonal=1)
attn_mask = attn_mask * torch.finfo(dtype).min attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype, device='cuda') attn_mask = attn_mask.to(dtype=dtype, device='cuda')
@ -125,8 +125,8 @@ def ref_multi_query_cached_kv_attention(
block_table = block_tables[i] block_table = block_tables[i]
# Create attention mask # Create attention mask
attn_mask = torch.triu( attn_mask = torch.triu(torch.ones(query_len, context_len),
torch.ones(query_len, context_len), diagonal=context_len - query_len + 1) * -1e5 diagonal=context_len - query_len + 1) * -1e5
attn_mask = attn_mask.to(dtype=dtype, device='cuda') attn_mask = attn_mask.to(dtype=dtype, device='cuda')
keys = [] keys = []
@ -164,20 +164,27 @@ def run_single_query_cached_kv_attention(
block_size: int, block_size: int,
num_blocks: int, num_blocks: int,
dtype: torch.dtype, dtype: torch.dtype,
num_kv_heads: int = None,
) -> None: ) -> None:
qkv = torch.empty( qkv = torch.empty(num_tokens,
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') 3,
num_heads,
head_size,
dtype=dtype,
device='cuda')
qkv.uniform_(-1e-3, 1e-3) qkv.uniform_(-1e-3, 1e-3)
query, _, _ = qkv.unbind(dim=1) query, _, _ = qkv.unbind(dim=1)
x = 16 // torch.tensor([], dtype=dtype).element_size() x = 16 // torch.tensor([], dtype=dtype).element_size()
key_block_shape = (num_heads, head_size // x, block_size, x) key_block_shape = (num_heads, head_size // x, block_size, x)
key_cache = torch.empty( key_cache = torch.empty(size=(num_blocks, *key_block_shape),
size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda') dtype=dtype,
device='cuda')
key_cache.uniform_(-1e-3, 1e-3) key_cache.uniform_(-1e-3, 1e-3)
value_block_shape = (num_heads, head_size, block_size) value_block_shape = (num_heads, head_size, block_size)
value_cache = torch.empty( value_cache = torch.empty(size=(num_blocks, *value_block_shape),
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda') dtype=dtype,
device='cuda')
value_cache.uniform_(-1e-3, 1e-3) value_cache.uniform_(-1e-3, 1e-3)
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)] context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
@ -193,20 +200,34 @@ def run_single_query_cached_kv_attention(
] ]
block_tables.append(block_table) block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda') block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda")
scale = float(1.0 / (head_size**0.5)) scale = float(1.0 / (head_size**0.5))
output = torch.empty(
num_tokens, num_heads, head_size, dtype=dtype, device='cuda') num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
assert num_heads % num_kv_heads == 0
num_queries_per_kv = num_heads // num_kv_heads
head_mapping = torch.repeat_interleave(
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
num_queries_per_kv)
output = torch.empty(num_tokens,
num_heads,
head_size,
dtype=dtype,
device='cuda')
attention_ops.single_query_cached_kv_attention( attention_ops.single_query_cached_kv_attention(
output, output,
query, query,
key_cache, key_cache,
value_cache, value_cache,
head_mapping,
scale, scale,
block_tables, block_tables,
context_lens, context_lens,
block_size, block_size,
max_context_len, max_context_len,
None, # ALiBi slopes.
) )
ref_output = torch.empty_like(query) ref_output = torch.empty_like(query)
@ -236,8 +257,12 @@ def run_multi_query_kv_attention(
num_tokens = sum(seq_lens) num_tokens = sum(seq_lens)
scale = float(1.0 / (head_size**0.5)) scale = float(1.0 / (head_size**0.5))
qkv = torch.empty( qkv = torch.empty(num_tokens,
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') 3,
num_heads,
head_size,
dtype=dtype,
device='cuda')
qkv.uniform_(-1e-3, 1e-3) qkv.uniform_(-1e-3, 1e-3)
query, key, value = qkv.unbind(dim=1) query, key, value = qkv.unbind(dim=1)
@ -272,7 +297,7 @@ def test_single_query_cached_kv_attention() -> None:
torch.cuda.manual_seed(TEST_SEED) torch.cuda.manual_seed(TEST_SEED)
for dtype in [torch.half, torch.bfloat16, torch.float]: for dtype in [torch.half, torch.bfloat16, torch.float]:
for block_size in [8, 16, 32]: for block_size in [8, 16, 32]:
for head_size in [64, 80, 96, 128]: for head_size in [64, 80, 96, 112, 128, 256]:
print(f'Testing single_query_cached_kv_attention with ' print(f'Testing single_query_cached_kv_attention with '
f'dtype={dtype}, block_size={block_size}, ' f'dtype={dtype}, block_size={block_size}, '
f'head_size={head_size}') f'head_size={head_size}')
@ -290,7 +315,7 @@ def test_multi_query_kv_attention() -> None:
torch.random.manual_seed(TEST_SEED) torch.random.manual_seed(TEST_SEED)
torch.cuda.manual_seed(TEST_SEED) torch.cuda.manual_seed(TEST_SEED)
for dtype in [torch.half, torch.bfloat16, torch.float]: for dtype in [torch.half, torch.bfloat16, torch.float]:
for head_size in [64, 80, 96, 128]: for head_size in [64, 80, 96, 112, 128, 256]:
print(f'Testing multi_query_kv_attention with dtype={dtype}, ' print(f'Testing multi_query_kv_attention with dtype={dtype}, '
f'head_size={head_size}') f'head_size={head_size}')
run_multi_query_kv_attention( run_multi_query_kv_attention(

View File

@ -26,8 +26,9 @@ def run_copy_blocks(
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_caches = [] key_caches = []
for _ in range(num_layers): for _ in range(num_layers):
key_cache = torch.randn( key_cache = torch.randn(size=key_cache_shape,
size=key_cache_shape, dtype=dtype, device='cuda') dtype=dtype,
device='cuda')
key_caches.append(key_cache) key_caches.append(key_cache)
cloned_key_caches = [] cloned_key_caches = []
for key_cache in key_caches: for key_cache in key_caches:
@ -36,8 +37,9 @@ def run_copy_blocks(
value_cache_shape = (num_blocks, num_heads, head_size, block_size) value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_caches = [] value_caches = []
for _ in range(num_layers): for _ in range(num_layers):
value_cache = torch.randn( value_cache = torch.randn(size=value_cache_shape,
size=value_cache_shape, dtype=dtype, device='cuda') dtype=dtype,
device='cuda')
value_caches.append(value_cache) value_caches.append(value_cache)
cloned_value_caches = [] cloned_value_caches = []
for value_cache in value_caches: for value_cache in value_caches:
@ -49,15 +51,18 @@ def run_copy_blocks(
# Reference implementation. # Reference implementation.
for src, dsts in block_mapping.items(): for src, dsts in block_mapping.items():
for dst in dsts: for dst in dsts:
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): for key_cache, cloned_key_cache in zip(key_caches,
cloned_key_caches):
cloned_key_cache[dst] = cloned_key_cache[src] cloned_key_cache[dst] = cloned_key_cache[src]
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches): for value_cache, cloned_value_cache in zip(value_caches,
cloned_value_caches):
cloned_value_cache[dst] = cloned_value_cache[src] cloned_value_cache[dst] = cloned_value_cache[src]
# Compare the results. # Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
assert torch.allclose(key_cache, cloned_key_cache) assert torch.allclose(key_cache, cloned_key_cache)
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches): for value_cache, cloned_value_cache in zip(value_caches,
cloned_value_caches):
assert torch.allclose(value_cache, cloned_value_cache) assert torch.allclose(value_cache, cloned_value_cache)
@ -74,8 +79,12 @@ def run_reshape_and_cache(
slot_mapping = random.sample(range(num_slots), num_tokens) slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda') slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
qkv = torch.randn( qkv = torch.randn(num_tokens,
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') 3,
num_heads,
head_size,
dtype=dtype,
device='cuda')
_, key, value = qkv.unbind(dim=1) _, key, value = qkv.unbind(dim=1)
x = 16 // torch.tensor([], dtype=dtype).element_size() x = 16 // torch.tensor([], dtype=dtype).element_size()
@ -84,15 +93,19 @@ def run_reshape_and_cache(
cloned_key_cache = key_cache.clone() cloned_key_cache = key_cache.clone()
value_cache_shape = (num_blocks, num_heads, head_size, block_size) value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_cache = torch.randn( value_cache = torch.randn(size=value_cache_shape,
size=value_cache_shape, dtype=dtype, device='cuda') dtype=dtype,
device='cuda')
cloned_value_cache = value_cache.clone() cloned_value_cache = value_cache.clone()
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping) cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping)
for i in range(num_tokens): for i in range(num_tokens):
reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x) reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor') block_idx = torch.div(slot_mapping[i],
block_size,
rounding_mode='floor')
block_offset = slot_mapping[i] % block_size block_offset = slot_mapping[i] % block_size
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i] cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
cloned_value_cache[block_idx, :, :, block_offset] = value[i] cloned_value_cache[block_idx, :, :, block_offset] = value[i]
@ -114,8 +127,12 @@ def run_gather_cached_kv(
slot_mapping = random.sample(range(num_slots), num_tokens) slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda') slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
qkv = torch.randn( qkv = torch.randn(num_tokens,
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') 3,
num_heads,
head_size,
dtype=dtype,
device='cuda')
_, key, value = qkv.unbind(dim=1) _, key, value = qkv.unbind(dim=1)
qkv_clone = qkv.clone() qkv_clone = qkv.clone()
@ -126,15 +143,20 @@ def run_gather_cached_kv(
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda') key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
value_cache_shape = (num_blocks, num_heads, head_size, block_size) value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_cache = torch.randn( value_cache = torch.randn(size=value_cache_shape,
size=value_cache_shape, dtype=dtype, device='cuda') dtype=dtype,
device='cuda')
cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping) cache_ops.gather_cached_kv(key, value, key_cache, value_cache,
slot_mapping)
# Reference implementation. # Reference implementation.
for i in range(num_tokens): for i in range(num_tokens):
reshaped_key = cloned_key.reshape(num_tokens, num_heads, head_size // x, x) reshaped_key = cloned_key.reshape(num_tokens, num_heads,
block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor') head_size // x, x)
block_idx = torch.div(slot_mapping[i],
block_size,
rounding_mode='floor')
block_offset = slot_mapping[i] % block_size block_offset = slot_mapping[i] % block_size
reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :] reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
cloned_value[i] = value_cache[block_idx, :, :, block_offset] cloned_value[i] = value_cache[block_idx, :, :, block_offset]
@ -145,20 +167,30 @@ def run_gather_cached_kv(
def test_copy_blocks() -> None: def test_copy_blocks() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]: for dtype in [torch.half, torch.bfloat16, torch.float]:
run_copy_blocks( run_copy_blocks(num_mappings=23,
num_mappings=23, num_layers=7, num_heads=17, head_size=16, num_layers=7,
block_size=8, num_blocks=1024, dtype=dtype) num_heads=17,
head_size=16,
block_size=8,
num_blocks=1024,
dtype=dtype)
def test_reshape_and_cache() -> None: def test_reshape_and_cache() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]: for dtype in [torch.half, torch.bfloat16, torch.float]:
run_reshape_and_cache( run_reshape_and_cache(num_tokens=3,
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2, num_heads=2,
head_size=16,
block_size=8,
num_blocks=2,
dtype=dtype) dtype=dtype)
def test_gather_cached_kv() -> None: def test_gather_cached_kv() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]: for dtype in [torch.half, torch.bfloat16, torch.float]:
run_gather_cached_kv( run_gather_cached_kv(num_tokens=3,
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2, num_heads=2,
head_size=16,
block_size=8,
num_blocks=2,
dtype=dtype) dtype=dtype)

View File

@ -14,8 +14,10 @@ class RefRMSNorm(nn.Module):
self.variance_epsilon = eps self.variance_epsilon = eps
def forward(self, hidden_states): def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) variance = hidden_states.to(torch.float32).pow(2).mean(-1,
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance +
self.variance_epsilon)
if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]: if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype) hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states return self.weight * hidden_states

View File

@ -59,7 +59,6 @@ class RefRotaryEmbeddingNeox(nn.Module):
key_rot = key[..., :self.rotary_dim] key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:] key_pass = key[..., self.rotary_dim:]
query_rot = query_rot.transpose(0, 1) query_rot = query_rot.transpose(0, 1)
key_rot = key_rot.transpose(0, 1) key_rot = key_rot.transpose(0, 1)
cos = F.embedding(positions, self.cos_cached) cos = F.embedding(positions, self.cos_cached)
@ -86,8 +85,14 @@ def run_rotary_embedding_neox(
base: int = 10000, base: int = 10000,
) -> None: ) -> None:
positions = torch.randint(0, max_position, (num_tokens, ), device='cuda') positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda') query = torch.randn(num_tokens,
key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda') num_heads * head_size,
dtype=dtype,
device='cuda')
key = torch.randn(num_tokens,
num_heads * head_size,
dtype=dtype,
device='cuda')
# Create the rotary embedding. # Create the rotary embedding.
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim)) inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))

View File

@@ -1,3 +1,5 @@
+ """vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
@@ -6,7 +8,7 @@ from vllm.entrypoints.llm import LLM
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
- __version__ = "0.1.1"
+ __version__ = "0.1.4"
__all__ = [
"LLM",

View File

@ -35,7 +35,8 @@ class LogicalTokenBlock:
def append_tokens(self, token_ids: List[int]) -> None: def append_tokens(self, token_ids: List[int]) -> None:
assert len(token_ids) <= self.get_num_empty_slots() assert len(token_ids) <= self.get_num_empty_slots()
self.token_ids[self.num_tokens:self.num_tokens + len(token_ids)] = token_ids curr_idx = self.num_tokens
self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
self.num_tokens += len(token_ids) self.num_tokens += len(token_ids)
def get_token_ids(self) -> List[int]: def get_token_ids(self) -> List[int]:

View File

@@ -1,14 +1,15 @@
from typing import Optional
import torch
- from transformers import AutoConfig, PretrainedConfig
+ from transformers import PretrainedConfig
from vllm.logger import init_logger
+ from vllm.transformers_utils.config import get_config
from vllm.utils import get_cpu_memory
logger = init_logger(__name__)
- _GiB = 1 << 30
+ _GB = 1 << 30
class ModelConfig:
@ -16,6 +17,11 @@ class ModelConfig:
Args: Args:
model: Name or path of the huggingface model to use. model: Name or path of the huggingface model to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
download_dir: Directory to download and load the weights, default to the download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface. default cache directory of huggingface.
use_np_weights: Save a numpy copy of model weights for faster loading. use_np_weights: Save a numpy copy of model weights for faster loading.
@ -30,6 +36,9 @@ class ModelConfig:
def __init__( def __init__(
self, self,
model: str, model: str,
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
download_dir: Optional[str], download_dir: Optional[str],
use_np_weights: bool, use_np_weights: bool,
use_dummy_weights: bool, use_dummy_weights: bool,
@ -37,13 +46,25 @@ class ModelConfig:
seed: int, seed: int,
) -> None: ) -> None:
self.model = model self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
self.download_dir = download_dir self.download_dir = download_dir
self.use_np_weights = use_np_weights self.use_np_weights = use_np_weights
self.use_dummy_weights = use_dummy_weights self.use_dummy_weights = use_dummy_weights
self.seed = seed self.seed = seed
self.hf_config: PretrainedConfig = AutoConfig.from_pretrained(model) self.hf_config = get_config(model, trust_remote_code)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self._verify_tokenizer_mode()
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow"]:
raise ValueError(
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
"either 'auto' or 'slow'.")
self.tokenizer_mode = tokenizer_mode
def verify_with_parallel_config( def verify_with_parallel_config(
self, self,
@ -73,9 +94,48 @@ class ModelConfig:
return self.hf_config.hidden_size // self.hf_config.num_attention_heads return self.hf_config.hidden_size // self.hf_config.num_attention_heads
def get_num_heads(self, parallel_config: "ParallelConfig") -> int: def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
# For GPTBigCode & Falcon:
# Note: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of
# KV heads.
new_decoder_arch_falcon = (
self.hf_config.model_type == "falcon"
and getattr(self.hf_config, "new_decoder_architecture", False))
if not new_decoder_arch_falcon and getattr(self.hf_config,
"multi_query", False):
# Multi-query attention, only one KV head.
return 1
# For Falcon:
if getattr(self.hf_config, "n_head_kv", None) is not None:
return (self.hf_config.n_head_kv //
parallel_config.tensor_parallel_size)
# For LLaMA-2:
if getattr(self.hf_config, "num_key_value_heads", None) is not None:
return (self.hf_config.num_key_value_heads //
parallel_config.tensor_parallel_size)
total_num_attention_heads = self.hf_config.num_attention_heads total_num_attention_heads = self.hf_config.num_attention_heads
return total_num_attention_heads // parallel_config.tensor_parallel_size return total_num_attention_heads // parallel_config.tensor_parallel_size
def get_max_model_len(self) -> int:
max_model_len = float("inf")
possible_keys = [
# OPT
"max_position_embeddings",
# GPT-2
"n_positions",
# MPT
"max_seq_len",
# Others
"max_sequence_length",
"max_seq_length",
"seq_len",
]
for key in possible_keys:
max_len_key = getattr(self.hf_config, key, None)
if max_len_key is not None:
max_model_len = min(max_model_len, max_len_key)
return max_model_len
def get_num_layers(self, parallel_config: "ParallelConfig") -> int: def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_config.num_hidden_layers total_num_hidden_layers = self.hf_config.num_hidden_layers
return total_num_hidden_layers // parallel_config.pipeline_parallel_size return total_num_hidden_layers // parallel_config.pipeline_parallel_size
@ -90,6 +150,7 @@ class CacheConfig:
vLLM execution. vLLM execution.
swap_space: Size of the CPU swap space per GPU (in GiB). swap_space: Size of the CPU swap space per GPU (in GiB).
""" """
def __init__( def __init__(
self, self,
block_size: int, block_size: int,
@ -98,7 +159,7 @@ class CacheConfig:
) -> None: ) -> None:
self.block_size = block_size self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GiB self.swap_space_bytes = swap_space * _GB
self._verify_args() self._verify_args()
# Will be set after profiling. # Will be set after profiling.
@ -121,14 +182,13 @@ class CacheConfig:
num_gpus_per_node = parallel_config.tensor_parallel_size num_gpus_per_node = parallel_config.tensor_parallel_size
cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
msg = ( msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
f"{cpu_memory_usage / _GiB:.2f} GiB out of " f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
f"the {total_cpu_memory / _GiB:.2f} GiB total CPU memory is "
"allocated for the swap space.") "allocated for the swap space.")
if cpu_memory_usage > 0.7 * total_cpu_memory: if cpu_memory_usage > 0.7 * total_cpu_memory:
raise ValueError("Too large swap space. " + msg) raise ValueError("Too large swap space. " + msg)
elif cpu_memory_usage > 0.4 * total_cpu_memory: elif cpu_memory_usage > 0.4 * total_cpu_memory:
logger.warn("Possibly too large swap space. " + msg) logger.warning("Possibly too large swap space. " + msg)
class ParallelConfig: class ParallelConfig:
@ -141,6 +201,7 @@ class ParallelConfig:
True if either pipeline_parallel_size or tensor_parallel_size is True if either pipeline_parallel_size or tensor_parallel_size is
greater than 1. greater than 1.
""" """
def __init__( def __init__(
self, self,
pipeline_parallel_size: int, pipeline_parallel_size: int,
@@ -170,14 +231,15 @@ class SchedulerConfig:
a single iteration.
max_num_seqs: Maximum number of sequences to be processed in a single
iteration.
+max_model_len: Maximum length of a sequence (including prompt
+    and generated text).
"""
-def __init__(
-    self,
-    max_num_batched_tokens: int,
-    max_num_seqs: int,
-) -> None:
+def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
+             max_model_len: int) -> None:
self.max_num_batched_tokens = max_num_batched_tokens
self.max_num_seqs = max_num_seqs
+self.max_model_len = max_model_len
_STR_DTYPE_TO_TORCH_DTYPE = {
@@ -221,7 +283,7 @@ def _get_and_verify_dtype(
pass
else:
# Casting between float16 and bfloat16 is allowed with a warning.
-logger.warn(f"Casting {config_dtype} to {torch_dtype}.")
+logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16:


@@ -27,8 +27,9 @@ class BlockAllocator:
# Initialize the free blocks.
self.free_blocks: List[PhysicalTokenBlock] = []
for i in range(num_blocks):
-block = PhysicalTokenBlock(
-    device=device, block_number=i, block_size=block_size)
+block = PhysicalTokenBlock(device=device,
+                           block_number=i,
+                           block_size=block_size)
self.free_blocks.append(block)
def allocate(self) -> PhysicalTokenBlock:
@@ -84,10 +85,12 @@ class BlockSpaceManager:
num_required_blocks = len(seq.logical_token_blocks)
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
# Use watermark to avoid frequent cache eviction.
-return num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks
+return (num_free_gpu_blocks - num_required_blocks >=
+        self.watermark_blocks)
def allocate(self, seq_group: SequenceGroup) -> None:
-# NOTE: Here we assume that all sequences in the group have the same prompt.
+# NOTE: Here we assume that all sequences in the group have the same
+# prompt.
seq = seq_group.get_seqs()[0]
# Allocate new physical token blocks that will store the prompt tokens.
@@ -143,7 +146,8 @@ class BlockSpaceManager:
for block in src_block_table:
block.ref_count += 1
-def _get_physical_blocks(self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
+def _get_physical_blocks(
+        self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
# NOTE: Here, we assume that the physical blocks are only shared by
# the sequences in the same group.
blocks: Set[PhysicalTokenBlock] = set()
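The reflowed can_allocate above is just a watermark test: a prompt is admitted only if enough free GPU blocks remain afterwards to absorb decode growth. A small illustrative sketch of that check, with plain integers standing in for the block manager's state:

def can_allocate(num_required_blocks: int, num_free_gpu_blocks: int,
                 watermark_blocks: int) -> bool:
    # Admit a new sequence group only if, after allocating its prompt blocks,
    # at least `watermark_blocks` free blocks remain.
    return num_free_gpu_blocks - num_required_blocks >= watermark_blocks

# With 100 free blocks and a watermark of 10, a 95-block prompt is rejected.
assert can_allocate(95, 100, 10) is False
assert can_allocate(80, 100, 10) is True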


@@ -12,8 +12,6 @@ from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
logger = init_logger(__name__)
-_LOGGING_INTERVAL_SEC = 5
class PreemptionMode(enum.Enum):
"""Preemption modes.
@@ -32,20 +30,28 @@ class SchedulerOutputs:
def __init__(
self,
+scheduled_seq_groups: List[SequenceGroup],
+prompt_run: bool,
+num_batched_tokens: int,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
+ignored_seq_groups: List[SequenceGroup],
) -> None:
+self.scheduled_seq_groups = scheduled_seq_groups
+self.prompt_run = prompt_run
+self.num_batched_tokens = num_batched_tokens
self.blocks_to_swap_in = blocks_to_swap_in
self.blocks_to_swap_out = blocks_to_swap_out
self.blocks_to_copy = blocks_to_copy
# Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out)
+self.ignored_seq_groups = ignored_seq_groups
def is_empty(self) -> bool:
-return (not self.blocks_to_swap_in
-        and not self.blocks_to_swap_out
-        and not self.blocks_to_copy)
+# NOTE: We do not consider the ignored sequence groups.
+return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
+        and not self.blocks_to_swap_out and not self.blocks_to_copy)
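For reference, here is a hedged, simplified stand-in for the expanded SchedulerOutputs (a plain dataclass, not the vLLM class itself), showing why is_empty now also looks at scheduled_seq_groups while deliberately ignoring ignored_seq_groups:

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class SchedulerOutputsSketch:
    # Simplified stand-in for the fields added in the diff above.
    scheduled_seq_groups: List[object]
    prompt_run: bool
    num_batched_tokens: int
    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
    blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
    ignored_seq_groups: List[object] = field(default_factory=list)

    def is_empty(self) -> bool:
        # Ignored groups are deliberately not counted here.
        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                and not self.blocks_to_swap_out and not self.blocks_to_copy)

outputs = SchedulerOutputsSketch([], prompt_run=False, num_batched_tokens=0)
assert outputs.is_empty()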
class Scheduler:
@@ -54,14 +60,12 @@ class Scheduler:
self,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
-log_stats: bool,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
-self.log_stats = log_stats
# Instantiate the scheduling policy.
-self.policy = PolicyFactory.get_policy(policy_name='fcfs')
+self.policy = PolicyFactory.get_policy(policy_name="fcfs")
# Create the block space manager.
self.block_manager = BlockSpaceManager(
block_size=self.cache_config.block_size,
@@ -76,10 +80,6 @@ class Scheduler:
# Sequence groups in the SWAPPED state.
self.swapped: List[SequenceGroup] = []
-self.last_logging_time: float = 0.0
-# List[timestamp, num_tokens]
-self.num_input_tokens: List[Tuple[float, int]] = []
def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue.
self.waiting.append(seq_group)
@@ -102,7 +102,7 @@ class Scheduler:
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)
-def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
+def _schedule(self) -> SchedulerOutputs:
# Blocks that need to be swaped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {}
@@ -111,10 +111,71 @@ class Scheduler:
# Fix the current time.
now = time.time()
-# NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
-# in order to minimize the preemption overheads.
-# Preemption happens only when there is no available slot to keep all
-# the sequence groups in the RUNNING state.
+# Join waiting sequences if possible.
+if not self.swapped:
+    ignored_seq_groups: List[SequenceGroup] = []
+    scheduled: List[SequenceGroup] = []
num_batched_tokens = 0
# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
# are added to the back.
while self.waiting:
seq_group = self.waiting[0]
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
prompt_limit = min(
self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens)
if num_prompt_tokens > prompt_limit:
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds limit of {prompt_limit}")
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
self.waiting.pop(0)
break
# If the sequence group cannot be allocated, stop.
if not self.block_manager.can_allocate(seq_group):
break
# If the number of batched tokens exceeds the limit, stop.
if (num_batched_tokens + num_prompt_tokens >
self.scheduler_config.max_num_batched_tokens):
break
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.num_seqs(
status=SequenceStatus.WAITING)
num_curr_seqs = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)
if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs):
break
seq_group = self.waiting.pop(0)
self._allocate(seq_group)
self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens
scheduled.append(seq_group)
if scheduled:
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled,
prompt_run=True,
num_batched_tokens=num_batched_tokens,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
ignored_seq_groups=ignored_seq_groups,
)
return scheduler_outputs
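The new prompt branch ignores any waiting prompt longer than min(max_model_len, max_num_batched_tokens). A tiny sketch of that cutoff (the numbers are illustrative only):

def exceeds_prompt_limit(num_prompt_tokens: int, max_model_len: int,
                         max_num_batched_tokens: int) -> bool:
    # A prompt is ignored when it is longer than either the model's context
    # window or the per-iteration token budget, whichever is smaller.
    prompt_limit = min(max_model_len, max_num_batched_tokens)
    return num_prompt_tokens > prompt_limit

assert exceeds_prompt_limit(3000, max_model_len=2048, max_num_batched_tokens=2560)
assert not exceeds_prompt_limit(512, max_model_len=2048, max_num_batched_tokens=2560)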
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# In this case, the policy is responsible for deciding which sequence # In this case, the policy is responsible for deciding which sequence
# groups to preempt. # groups to preempt.
self.running = self.policy.sort_by_priority(now, self.running) self.running = self.policy.sort_by_priority(now, self.running)
@ -156,8 +217,11 @@ class Scheduler:
# The total number of sequences in the RUNNING state should not # The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences. # exceed the maximum number of sequences.
num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
num_curr_seqs = len(self.running) num_curr_seqs = sum(
if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs: seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)
if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs):
break break
seq_group = self.swapped.pop(0) seq_group = self.swapped.pop(0)
@ -167,106 +231,28 @@ class Scheduler:
num_batched_tokens = sum( num_batched_tokens = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING) seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running for seq_group in self.running)
)
# Join waiting sequences if possible.
prompt_group_ids: List[str] = []
# NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
# prioritized over the sequence groups in the WAITING state.
# This is because we want to bound the amount of CPU memory taken by
# the swapped sequence groups.
if not self.swapped:
# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
# are added to the back.
while self.waiting:
seq_group = self.waiting[0]
# If the sequence group has been preempted in this step, stop.
if seq_group in preempted:
break
# If the sequence group cannot be allocated, stop.
if not self.block_manager.can_allocate(seq_group):
break
# If the number of batched tokens exceeds the limit, stop.
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
if (num_batched_tokens + num_prompt_tokens
> self.scheduler_config.max_num_batched_tokens):
break
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
num_curr_seqs = len(self.running)
if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
break
seq_group = self.waiting.pop(0)
self._allocate(seq_group)
self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens
prompt_group_ids.append(seq_group.request_id)
scheduler_outputs = SchedulerOutputs(
+    scheduled_seq_groups=self.running,
+    prompt_run=False,
+    num_batched_tokens=num_batched_tokens,
    blocks_to_swap_in=blocks_to_swap_in,
    blocks_to_swap_out=blocks_to_swap_out,
    blocks_to_copy=blocks_to_copy,
+    ignored_seq_groups=[],
)
-if not self.log_stats:
-    return scheduler_outputs, prompt_group_ids
+return scheduler_outputs
# TODO(woosuk): Move the below code to the engine.
now = time.time()
if num_batched_tokens > 0:
self.num_input_tokens.append((now, num_batched_tokens))
elapsed_time = now - self.last_logging_time
if elapsed_time > _LOGGING_INTERVAL_SEC:
self.last_logging_time = now
self.num_input_tokens = [
(t, n) for t, n in self.num_input_tokens
if now - t < _LOGGING_INTERVAL_SEC
]
if len(self.num_input_tokens) > 1:
total_num_tokens = sum(n for _, n in self.num_input_tokens[:-1])
window = now - self.num_input_tokens[0][0]
avg_throughput = total_num_tokens / window
else:
avg_throughput = 0.0
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
if total_num_cpu_blocks > 0:
num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
else:
cpu_cache_usage = 0.0
logger.info(
f"Throughput: {avg_throughput:.1f} tokens/s, "
f"Running: {len(self.running)} reqs, "
f"Swapped: {len(self.swapped)} reqs, "
f"Pending: {len(self.waiting)} reqs, "
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
return scheduler_outputs, prompt_group_ids
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting.
-scheduler_outputs, prompt_group_ids = self._schedule()
+scheduler_outputs = self._schedule()
# Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = []
-for seq_group in self.running:
-    is_prompt = seq_group.request_id in prompt_group_ids
+for seq_group in scheduler_outputs.scheduled_seq_groups:
seq_data: Dict[int, List[SequenceData]] = {}
block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@ -276,7 +262,7 @@ class Scheduler:
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id, request_id=seq_group.request_id,
is_prompt=is_prompt, is_prompt=scheduler_outputs.prompt_run,
seq_data=seq_data, seq_data=seq_data,
sampling_params=seq_group.sampling_params, sampling_params=seq_group.sampling_params,
block_tables=block_tables, block_tables=block_tables,
@ -288,14 +274,21 @@ class Scheduler:
self, self,
seq_outputs: Dict[int, SequenceOutputs], seq_outputs: Dict[int, SequenceOutputs],
) -> List[SequenceGroup]: ) -> List[SequenceGroup]:
# Update the running sequences and free blocks. scheduled: List[SequenceGroup] = []
for seq_group in self.running: for seq_group in self.running:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
if seq.seq_id in seq_outputs:
scheduled.append(seq_group)
break
# Update the scheduled sequences and free blocks.
for seq_group in scheduled:
# Process beam search results before processing the new tokens. # Process beam search results before processing the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
output = seq_outputs[seq.seq_id] output = seq_outputs[seq.seq_id]
if seq.seq_id != output.parent_seq_id: if seq.seq_id != output.parent_seq_id:
# The sequence is a fork of the parent sequence (beam search). # The sequence is a fork of the parent sequence (beam
# Free the current sequence. # search). Free the current sequence.
self.block_manager.free(seq) self.block_manager.free(seq)
# Fork the parent sequence. # Fork the parent sequence.
parent_seq = seq_group.find(output.parent_seq_id) parent_seq = seq_group.find(output.parent_seq_id)
@ -307,9 +300,7 @@ class Scheduler:
# Append a new token to the sequence. # Append a new token to the sequence.
output = seq_outputs[seq.seq_id] output = seq_outputs[seq.seq_id]
seq.append_token_id(output.output_token, output.logprobs) seq.append_token_id(output.output_token, output.logprobs)
# Return a shallow copy of the running queue to prevent the queue return scheduled
# from being modified by the caller.
return self.running.copy()
def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None: def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
seq.status = finish_status seq.status = finish_status
@ -368,7 +359,7 @@ class Scheduler:
elif preemption_mode == PreemptionMode.SWAP: elif preemption_mode == PreemptionMode.SWAP:
self._preempt_by_swap(seq_group, blocks_to_swap_out) self._preempt_by_swap(seq_group, blocks_to_swap_out)
else: else:
assert False, 'Invalid preemption mode.' assert False, "Invalid preemption mode."
def _preempt_by_recompute( def _preempt_by_recompute(
self, self,
@ -388,9 +379,6 @@ class Scheduler:
seq_group: SequenceGroup, seq_group: SequenceGroup,
blocks_to_swap_out: Dict[int, int], blocks_to_swap_out: Dict[int, int],
) -> None: ) -> None:
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
for seq in seqs:
seq.status = SequenceStatus.SWAPPED
self._swap_out(seq_group, blocks_to_swap_out) self._swap_out(seq_group, blocks_to_swap_out)
self.swapped.append(seq_group) self.swapped.append(seq_group)


@ -11,10 +11,13 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
class EngineArgs: class EngineArgs:
"""Arguments for vLLM engine.""" """Arguments for vLLM engine."""
model: str model: str
tokenizer: Optional[str] = None
tokenizer_mode: str = 'auto'
trust_remote_code: bool = False
download_dir: Optional[str] = None download_dir: Optional[str] = None
use_np_weights: bool = False use_np_weights: bool = False
use_dummy_weights: bool = False use_dummy_weights: bool = False
dtype: str = "auto" dtype: str = 'auto'
seed: int = 0 seed: int = 0
worker_use_ray: bool = False worker_use_ray: bool = False
pipeline_parallel_size: int = 1 pipeline_parallel_size: int = 1
@ -27,72 +30,110 @@ class EngineArgs:
disable_log_stats: bool = False disable_log_stats: bool = False
def __post_init__(self): def __post_init__(self):
if self.tokenizer is None:
self.tokenizer = self.model
self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens) self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
@staticmethod @staticmethod
def add_cli_args( def add_cli_args(
parser: argparse.ArgumentParser, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
) -> argparse.ArgumentParser:
"""Shared CLI arguments for vLLM engine.""" """Shared CLI arguments for vLLM engine."""
# Model arguments # Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m', parser.add_argument(
'--model',
type=str,
default='facebook/opt-125m',
help='name or path of the huggingface model to use') help='name or path of the huggingface model to use')
parser.add_argument('--download-dir', type=str, parser.add_argument(
'--tokenizer',
type=str,
default=EngineArgs.tokenizer,
help='name or path of the huggingface tokenizer to use')
parser.add_argument('--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow'],
help='tokenizer mode. "auto" will use the fast '
'tokenizer if available, and "slow" will '
'always use the slow tokenizer.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument('--download-dir',
type=str,
default=EngineArgs.download_dir, default=EngineArgs.download_dir,
help='directory to download and load the weights, ' help='directory to download and load the weights, '
'default to the default cache dir of ' 'default to the default cache dir of '
'huggingface') 'huggingface')
parser.add_argument('--use-np-weights', action='store_true', parser.add_argument('--use-np-weights',
action='store_true',
help='save a numpy copy of model weights for ' help='save a numpy copy of model weights for '
'faster loading. This can increase the disk ' 'faster loading. This can increase the disk '
'usage by up to 2x.') 'usage by up to 2x.')
parser.add_argument('--use-dummy-weights', action='store_true', parser.add_argument('--use-dummy-weights',
action='store_true',
help='use dummy values for model weights') help='use dummy values for model weights')
# TODO(woosuk): Support FP32. # TODO(woosuk): Support FP32.
parser.add_argument('--dtype', type=str, default=EngineArgs.dtype, parser.add_argument(
'--dtype',
type=str,
default=EngineArgs.dtype,
choices=['auto', 'half', 'bfloat16', 'float'], choices=['auto', 'half', 'bfloat16', 'float'],
help='data type for model weights and activations. ' help='data type for model weights and activations. '
'The "auto" option will use FP16 precision ' 'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision ' 'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.') 'for BF16 models.')
# Parallel arguments # Parallel arguments
parser.add_argument('--worker-use-ray', action='store_true', parser.add_argument('--worker-use-ray',
action='store_true',
help='use Ray for distributed serving, will be ' help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU') 'automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, parser.add_argument('--pipeline-parallel-size',
'-pp',
type=int,
default=EngineArgs.pipeline_parallel_size, default=EngineArgs.pipeline_parallel_size,
help='number of pipeline stages') help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, parser.add_argument('--tensor-parallel-size',
'-tp',
type=int,
default=EngineArgs.tensor_parallel_size, default=EngineArgs.tensor_parallel_size,
help='number of tensor parallel replicas') help='number of tensor parallel replicas')
# KV cache arguments # KV cache arguments
parser.add_argument('--block-size', type=int, parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size, default=EngineArgs.block_size,
choices=[8, 16, 32], choices=[8, 16, 32],
help='token block size') help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request). # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=EngineArgs.seed, parser.add_argument('--seed',
type=int,
default=EngineArgs.seed,
help='random seed') help='random seed')
parser.add_argument('--swap-space', type=int, parser.add_argument('--swap-space',
type=int,
default=EngineArgs.swap_space, default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU') help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization', type=float, parser.add_argument('--gpu-memory-utilization',
type=float,
default=EngineArgs.gpu_memory_utilization, default=EngineArgs.gpu_memory_utilization,
help='the percentage of GPU memory to be used for' help='the percentage of GPU memory to be used for'
'the model executor') 'the model executor')
parser.add_argument('--max-num-batched-tokens', type=int, parser.add_argument('--max-num-batched-tokens',
type=int,
default=EngineArgs.max_num_batched_tokens, default=EngineArgs.max_num_batched_tokens,
help='maximum number of batched tokens per ' help='maximum number of batched tokens per '
'iteration') 'iteration')
parser.add_argument('--max-num-seqs', type=int, parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs, default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration') help='maximum number of sequences per iteration')
parser.add_argument('--disable-log-stats', action='store_true', parser.add_argument('--disable-log-stats',
action='store_true',
help='disable logging statistics') help='disable logging statistics')
return parser return parser
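As a hedged usage sketch of the flag layout above, here is a throwaway parser with a few of the same options (only a subset, built directly rather than via EngineArgs.add_cli_args):

import argparse

parser = argparse.ArgumentParser(description="vLLM engine args (sketch)")
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--tokenizer-mode', type=str, default='auto',
                    choices=['auto', 'slow'])
parser.add_argument('--trust-remote-code', action='store_true')

args = parser.parse_args(['--model', 'facebook/opt-125m', '--trust-remote-code'])
assert args.trust_remote_code and args.tokenizer_mode == 'auto'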
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs": def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
# Get the list of attributes of this dataclass. # Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)] attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments. # Set the attributes from the parsed arguments.
@ -103,16 +144,20 @@ class EngineArgs:
self, self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
# Initialize the configs. # Initialize the configs.
model_config = ModelConfig( model_config = ModelConfig(self.model, self.tokenizer,
self.model, self.download_dir, self.use_np_weights, self.tokenizer_mode, self.trust_remote_code,
self.use_dummy_weights, self.dtype, self.seed) self.download_dir, self.use_np_weights,
cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.use_dummy_weights, self.dtype,
self.seed)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space) self.swap_space)
parallel_config = ParallelConfig(self.pipeline_parallel_size, parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size, self.tensor_parallel_size,
self.worker_use_ray) self.worker_use_ray)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens, scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs) self.max_num_seqs,
model_config.get_max_model_len())
return model_config, cache_config, parallel_config, scheduler_config return model_config, cache_config, parallel_config, scheduler_config
@ -124,12 +169,13 @@ class AsyncEngineArgs(EngineArgs):
@staticmethod @staticmethod
def add_cli_args( def add_cli_args(
parser: argparse.ArgumentParser, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
) -> argparse.ArgumentParser:
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray', action='store_true', parser.add_argument('--engine-use-ray',
action='store_true',
help='use Ray to start the LLM engine in a ' help='use Ray to start the LLM engine in a '
'separate process as the server process.') 'separate process as the server process.')
parser.add_argument('--disable-log-requests', action='store_true', parser.add_argument('--disable-log-requests',
action='store_true',
help='disable logging requests') help='disable logging requests')
return parser return parser


@ -2,6 +2,7 @@ import asyncio
import time import time
from typing import Dict, List, Optional from typing import Dict, List, Optional
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster, ray from vllm.engine.ray_utils import initialize_cluster, ray
@ -35,8 +36,13 @@ class AsyncLLMEngine:
log_requests: Whether to log the requests. log_requests: Whether to log the requests.
*args, *kwargs: Arguments for LLMEngine. *args, *kwargs: Arguments for LLMEngine.
""" """
def __init__(self, worker_use_ray: bool, engine_use_ray: bool,
log_requests: bool = True, *args, **kwargs) -> None: def __init__(self,
worker_use_ray: bool,
engine_use_ray: bool,
*args,
log_requests: bool = True,
**kwargs) -> None:
self.worker_use_ray = worker_use_ray self.worker_use_ray = worker_use_ray
self.engine_use_ray = engine_use_ray self.engine_use_ray = engine_use_ray
self.log_requests = log_requests self.log_requests = log_requests
@ -80,8 +86,7 @@ class AsyncLLMEngine:
prompt: Optional[str], prompt: Optional[str],
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
prompt_token_ids: Optional[List[int]] = None prompt_token_ids: Optional[List[int]] = None) -> RequestOutput:
) -> RequestOutput:
"""Generate outputs for a request. """Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the Generate outputs for a request. This method is a coroutine. It adds the
@ -117,12 +122,15 @@ class AsyncLLMEngine:
# Add the request into the vLLM engine's waiting queue. # Add the request into the vLLM engine's waiting queue.
if self.engine_use_ray: if self.engine_use_ray:
await self.engine.add_request.remote( await self.engine.add_request.remote(
request_id, prompt, sampling_params, request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time) arrival_time=arrival_time)
else: else:
self.engine.add_request( self.engine.add_request(request_id,
request_id, prompt, sampling_params, prompt,
sampling_params,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time) arrival_time=arrival_time)
@ -136,7 +144,11 @@ class AsyncLLMEngine:
# Kick the engine if the engine is not running. # Kick the engine if the engine is not running.
if not self.is_engine_running: if not self.is_engine_running:
try:
await self.engine_step(request_id) await self.engine_step(request_id)
except RuntimeError as e:
await self.abort(request_id)
raise e
# Wait for new output. The group_event will be set in engine_step # Wait for new output. The group_event will be set in engine_step
# when there is new output available for the sequence group. # when there is new output available for the sequence group.
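The new try/except around engine_step aborts the request before re-raising, so a failed step does not leave it queued forever. A minimal asyncio sketch of the same pattern, with hypothetical stand-ins for engine_step and abort:

import asyncio

async def kick_engine(engine_step, abort, request_id: str) -> None:
    # If driving the engine raises, abort the request first, then re-raise.
    try:
        await engine_step(request_id)
    except RuntimeError:
        await abort(request_id)
        raise

async def demo() -> None:
    aborted = []

    async def failing_step(_request_id):   # hypothetical stand-in
        raise RuntimeError("worker died")

    async def abort(request_id):           # hypothetical stand-in
        aborted.append(request_id)

    try:
        await kick_engine(failing_step, abort, "req-0")
    except RuntimeError:
        pass
    assert aborted == ["req-0"]

asyncio.run(demo())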
@ -199,20 +211,29 @@ class AsyncLLMEngine:
self.is_engine_running = False self.is_engine_running = False
self.kicking_request_id = None self.kicking_request_id = None
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_model_config.remote()
else:
return self.engine.get_model_config()
@classmethod @classmethod
def from_engine_args(cls, engine_args: AsyncEngineArgs) -> "AsyncLLMEngine": def from_engine_args(cls,
engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments.""" """Creates an async LLM engine from the engine arguments."""
# Create the engine configs. # Create the engine configs.
engine_configs = engine_args.create_engine_configs() engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2] parallel_config = engine_configs[2]
# Initialize the cluster. # Initialize the cluster.
distributed_init_method, devices = initialize_cluster( distributed_init_method, placement_group = initialize_cluster(
parallel_config, engine_args.engine_use_ray) parallel_config, engine_args.engine_use_ray)
# Create the async LLM engine. # Create the async LLM engine.
engine = cls(engine_args.worker_use_ray, engine = cls(engine_args.worker_use_ray,
engine_args.engine_use_ray, engine_args.engine_use_ray,
not engine_args.disable_log_requests,
*engine_configs, *engine_configs,
distributed_init_method, devices, distributed_init_method,
placement_group,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats) log_stats=not engine_args.disable_log_stats)
return engine return engine


@ -1,21 +1,32 @@
import time import time
from typing import Any, List, Optional import copy
from functools import partial
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig) SchedulerConfig)
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker
from vllm.engine.tokenizer_utils import detokenize_incrementally, get_tokenizer
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer)
from vllm.utils import Counter from vllm.utils import Counter
from vllm.worker.worker import Worker
if ray:
from ray.air.util.torch_dist import init_torch_dist_process_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__) logger = init_logger(__name__)
_LOGGING_INTERVAL_SEC = 5
class LLMEngine: class LLMEngine:
"""An LLM engine that receives requests and generates texts. """An LLM engine that receives requests and generates texts.
@ -53,19 +64,21 @@ class LLMEngine:
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
distributed_init_method: str, distributed_init_method: str,
stage_devices: List[List[DeviceID]], placement_group: Optional["PlacementGroup"],
log_stats: bool, log_stats: bool,
) -> None: ) -> None:
logger.info( logger.info(
"Initializing an LLM engine with config: " "Initializing an LLM engine with config: "
f"model={model_config.model!r}, " f"model={model_config.model!r}, "
f"tokenizer={model_config.tokenizer!r}, "
f"tokenizer_mode={model_config.tokenizer_mode}, "
f"trust_remote_code={model_config.trust_remote_code}, "
f"dtype={model_config.dtype}, " f"dtype={model_config.dtype}, "
f"use_dummy_weights={model_config.use_dummy_weights}, " f"use_dummy_weights={model_config.use_dummy_weights}, "
f"download_dir={model_config.download_dir!r}, " f"download_dir={model_config.download_dir!r}, "
f"use_np_weights={model_config.use_np_weights}, " f"use_np_weights={model_config.use_np_weights}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"seed={model_config.seed})" f"seed={model_config.seed})")
)
# TODO(woosuk): Print more configs in debug mode. # TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config self.model_config = model_config
@ -75,34 +88,89 @@ class LLMEngine:
self.log_stats = log_stats self.log_stats = log_stats
self._verify_args() self._verify_args()
self.tokenizer = get_tokenizer(model_config.model) self.tokenizer = get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code)
self.seq_counter = Counter() self.seq_counter = Counter()
# Create the parallel GPU workers. # Create the parallel GPU workers.
self.workers: List[Worker] = []
assert len(stage_devices) == 1, "Only support one stage for now."
for rank, node_resource, _ in stage_devices[0]:
worker_cls = Worker
if self.parallel_config.worker_use_ray: if self.parallel_config.worker_use_ray:
worker_cls = ray.remote( self._init_workers_ray(placement_group)
num_cpus=0, else:
num_gpus=1, self._init_workers(distributed_init_method)
resources={node_resource: 1e-3},
)(worker_cls).remote
worker = worker_cls(
model_config,
parallel_config,
scheduler_config,
rank,
distributed_init_method,
)
self.workers.append(worker)
# Profile the memory usage and initialize the cache. # Profile the memory usage and initialize the cache.
self._init_cache() self._init_cache()
# Create the scheduler. # Create the scheduler.
self.scheduler = Scheduler(scheduler_config, cache_config, log_stats) self.scheduler = Scheduler(scheduler_config, cache_config)
# Logging.
self.last_logging_time = 0.0
# List of (timestamp, num_tokens)
self.num_prompt_tokens: List[Tuple[float, int]] = []
# List of (timestamp, num_tokens)
self.num_generation_tokens: List[Tuple[float, int]] = []
def _init_workers(self, distributed_init_method: str):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
assert self.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
self.workers: List[Worker] = []
worker = Worker(
self.model_config,
self.parallel_config,
self.scheduler_config,
0,
distributed_init_method,
)
self.workers.append(worker)
self._run_workers(
"init_model",
get_all_outputs=True,
)
def _init_workers_ray(self, placement_group: "PlacementGroup"):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
self.workers: List[Worker] = []
for bundle in placement_group.bundle_specs:
if not bundle.get("GPU", 0):
continue
worker = ray.remote(
num_cpus=0,
num_gpus=1,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True),
)(RayWorker).remote()
self.workers.append(worker)
# Initialize torch distributed process group for the workers.
init_torch_dist_process_group(self.workers, backend="nccl")
model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config)
self._run_workers("init_worker",
get_all_outputs=True,
worker_init_fn=lambda: Worker(
model_config,
parallel_config,
scheduler_config,
None,
None,
))
self._run_workers(
"init_model",
get_all_outputs=True,
)
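The Ray path defers building the real Worker until each remote process has its GPU visibility set, by shipping a zero-argument factory (worker_init_fn) to a thin wrapper. A hedged local sketch of that indirection, with FakeWorker as a hypothetical stand-in for vllm.worker.worker.Worker and no Ray involved:

class LazyWorkerWrapper:
    """Sketch of the RayWorker idea: hold no Worker until told to build one."""

    def __init__(self) -> None:
        self.worker = None

    def init_worker(self, worker_init_fn) -> None:
        # The factory runs inside the (remote) process after its GPU
        # visibility is configured, so heavy CUDA imports happen late.
        self.worker = worker_init_fn()

    def execute_method(self, method: str, *args, **kwargs):
        return getattr(self.worker, method)(*args, **kwargs)

class FakeWorker:
    """Hypothetical stand-in for the real Worker class."""

    def init_model(self) -> str:
        return "model initialized"

wrapper = LazyWorkerWrapper()
wrapper.init_worker(FakeWorker)
assert wrapper.execute_method("init_model") == "model initialized"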
def _verify_args(self) -> None: def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
@ -125,10 +193,10 @@ class LLMEngine:
num_gpu_blocks = min(b[0] for b in num_blocks) num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks) num_cpu_blocks = min(b[1] for b in num_blocks)
# FIXME(woosuk): Change to debug log. # FIXME(woosuk): Change to debug log.
logger.info(f'# GPU blocks: {num_gpu_blocks}, ' logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f'# CPU blocks: {num_cpu_blocks}') f"# CPU blocks: {num_cpu_blocks}")
if num_gpu_blocks <= 0 or num_cpu_blocks <= 0: if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. " raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when " "Try increasing `gpu_memory_utilization` when "
"initializing the engine.") "initializing the engine.")
@ -146,9 +214,12 @@ class LLMEngine:
engine_configs = engine_args.create_engine_configs() engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2] parallel_config = engine_configs[2]
# Initialize the cluster. # Initialize the cluster.
distributed_init_method, devices = initialize_cluster(parallel_config) distributed_init_method, placement_group = initialize_cluster(
parallel_config)
# Create the LLM engine. # Create the LLM engine.
engine = cls(*engine_configs, distributed_init_method, devices, engine = cls(*engine_configs,
distributed_init_method,
placement_group,
log_stats=not engine_args.disable_log_stats) log_stats=not engine_args.disable_log_stats)
return engine return engine
@ -205,6 +276,10 @@ class LLMEngine:
""" """
self.scheduler.abort_seq_group(request_id) self.scheduler.abort_seq_group(request_id)
def get_model_config(self) -> ModelConfig:
"""Gets the model configuration."""
return self.model_config
def get_num_unfinished_requests(self) -> int: def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests.""" """Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups() return self.scheduler.get_num_unfinished_seq_groups()
@ -223,9 +298,16 @@ class LLMEngine:
the sequences and returns the newly generated results. the sequences and returns the newly generated results.
""" """
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if (not seq_group_metadata_list) and scheduler_outputs.is_empty(): if scheduler_outputs.is_empty():
if not scheduler_outputs.ignored_seq_groups:
# Nothing to do. # Nothing to do.
return [] return []
# If there are ignored seq groups, we need to return them as the
# request outputs.
return [
RequestOutput.from_seq_group(seq_group)
for seq_group in scheduler_outputs.ignored_seq_groups
]
# Execute the model. # Execute the model.
output = self._run_workers( output = self._run_workers(
@ -247,11 +329,79 @@ class LLMEngine:
# Create the outputs. # Create the outputs.
request_outputs: List[RequestOutput] = [] request_outputs: List[RequestOutput] = []
for seq_group in seq_groups: for seq_group in seq_groups + scheduler_outputs.ignored_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group) request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output) request_outputs.append(request_output)
if self.log_stats:
# Log the system stats.
self._log_system_stats(scheduler_outputs.prompt_run,
scheduler_outputs.num_batched_tokens)
return request_outputs return request_outputs
def _log_system_stats(
self,
prompt_run: bool,
num_batched_tokens: int,
) -> None:
now = time.time()
# Log the number of batched input tokens.
if prompt_run:
self.num_prompt_tokens.append((now, num_batched_tokens))
else:
self.num_generation_tokens.append((now, num_batched_tokens))
elapsed_time = now - self.last_logging_time
if elapsed_time < _LOGGING_INTERVAL_SEC:
return
# Discard the old stats.
self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens
if now - t < _LOGGING_INTERVAL_SEC]
self.num_generation_tokens = [(t, n)
for t, n in self.num_generation_tokens
if now - t < _LOGGING_INTERVAL_SEC]
if len(self.num_prompt_tokens) > 1:
total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
window = now - self.num_prompt_tokens[0][0]
avg_prompt_throughput = total_num_tokens / window
else:
avg_prompt_throughput = 0.0
if len(self.num_generation_tokens) > 1:
total_num_tokens = sum(n
for _, n in self.num_generation_tokens[:-1])
window = now - self.num_generation_tokens[0][0]
avg_generation_throughput = total_num_tokens / window
else:
avg_generation_throughput = 0.0
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
num_free_gpu_blocks = (
self.scheduler.block_manager.get_num_free_gpu_blocks())
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
if total_num_cpu_blocks > 0:
num_free_cpu_blocks = (
self.scheduler.block_manager.get_num_free_cpu_blocks())
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
else:
cpu_cache_usage = 0.0
logger.info("Avg prompt throughput: "
f"{avg_prompt_throughput:.1f} tokens/s, "
"Avg generation throughput: "
f"{avg_generation_throughput:.1f} tokens/s, "
f"Running: {len(self.scheduler.running)} reqs, "
f"Swapped: {len(self.scheduler.swapped)} reqs, "
f"Pending: {len(self.scheduler.waiting)} reqs, "
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
self.last_logging_time = now
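_log_system_stats keeps rolling (timestamp, num_tokens) samples and reports a windowed average. A self-contained sketch of that computation (the 5-second interval matches _LOGGING_INTERVAL_SEC above; the sample values are made up):

import time

def windowed_throughput(samples, now: float, interval: float = 5.0) -> float:
    # Keep only (timestamp, num_tokens) samples from the last `interval`
    # seconds and average all but the newest one over the elapsed window,
    # mirroring the logging code above.
    samples = [(t, n) for t, n in samples if now - t < interval]
    if len(samples) <= 1:
        return 0.0
    total = sum(n for _, n in samples[:-1])
    window = now - samples[0][0]
    return total / window

now = time.time()
samples = [(now - 4.0, 400), (now - 2.0, 400), (now - 0.5, 100)]
print(f"{windowed_throughput(samples, now):.1f} tokens/s")  # ~200 tokens/s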
def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None: def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
"""Decodes the sequence outputs.""" """Decodes the sequence outputs."""
for seq_group in seq_groups: for seq_group in seq_groups:
@ -262,6 +412,7 @@ class LLMEngine:
seq.get_last_token_id(), seq.get_last_token_id(),
skip_special_tokens=True, skip_special_tokens=True,
) )
if new_token is not None:
seq.output_tokens.append(new_token) seq.output_tokens.append(new_token)
seq.output_text = new_output_text seq.output_text = new_output_text
@ -277,13 +428,18 @@ class LLMEngine:
# Truncate the output text so that the stop string is # Truncate the output text so that the stop string is
# not included in the output. # not included in the output.
seq.output_text = seq.output_text[:-len(stop_str)] seq.output_text = seq.output_text[:-len(stop_str)]
self.scheduler.free_seq(seq, self.scheduler.free_seq(
SequenceStatus.FINISHED_STOPPED) seq, SequenceStatus.FINISHED_STOPPED)
stopped = True stopped = True
break break
if stopped: if stopped:
continue continue
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
continue
# Check if the sequence has reached max_tokens. # Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens: if seq.get_output_len() == sampling_params.max_tokens:
self.scheduler.free_seq( self.scheduler.free_seq(
@ -292,23 +448,24 @@ class LLMEngine:
# Check if the sequence has generated the EOS token. # Check if the sequence has generated the EOS token.
if not sampling_params.ignore_eos: if not sampling_params.ignore_eos:
if seq.get_last_token_id() == self.tokenizer.eos_token_id: if seq.get_last_token_id() == self.tokenizer.eos_token_id:
self.scheduler.free_seq(seq, self.scheduler.free_seq(
SequenceStatus.FINISHED_STOPPED) seq, SequenceStatus.FINISHED_STOPPED)
continue continue
def _run_workers( def _run_workers(
self, self,
method: str, method: str,
get_all_outputs: bool = False,
*args, *args,
get_all_outputs: bool = False,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers.""" """Runs the given method on all workers."""
all_outputs = [] all_outputs = []
for worker in self.workers: for worker in self.workers:
executor = getattr(worker, method)
if self.parallel_config.worker_use_ray: if self.parallel_config.worker_use_ray:
executor = executor.remote executor = partial(worker.execute_method.remote, method)
else:
executor = getattr(worker, method)
output = executor(*args, **kwargs) output = executor(*args, **kwargs)
all_outputs.append(output) all_outputs.append(output)


@@ -1,21 +1,49 @@
-import random
-from typing import List, Optional, Tuple
+import socket
+from typing import Optional, Tuple, TYPE_CHECKING
-try:
-    import ray
-except ImportError:
-    ray = None
from vllm.config import ParallelConfig
-DeviceID = Tuple[int, Optional[str], int]  # rank, node resource (node IP), device id
+try:
import ray
from ray.air.util.torch_dist import TorchDistributedWorker
class RayWorker(TorchDistributedWorker):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
def __init__(self) -> None:
self.worker = None
def init_worker(self, worker_init_fn):
self.worker = worker_init_fn()
def __getattr__(self, name):
return getattr(self.worker, name)
def execute_method(self, method, *args, **kwargs):
executor = getattr(self, method)
return executor(*args, **kwargs)
except ImportError:
ray = None
TorchDistributedWorker = None
RayWorker = None # pylint: disable=invalid-name
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def initialize_cluster(
parallel_config: ParallelConfig,
engine_use_ray: bool = False,
ray_address: Optional[str] = None,
-) -> Tuple[str, List[List[DeviceID]]]:
+) -> Tuple[str, Optional["PlacementGroup"]]:
"""Initialize the distributed cluster probably with Ray.
Args:
@@ -37,71 +65,46 @@ def initialize_cluster(
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
-ray.init(address=ray_address)
+ray.init(address=ray_address, ignore_reinit_error=True)
if not parallel_config.worker_use_ray:
# Initialize cluster locally.
-port = random.randint(10000, 20000)
+port = get_open_port()
# We need to setup the distributed init method to make sure
# the distributed megatron code (e.g., get world size) works correctly.
distributed_init_method = f"tcp://localhost:{port}"
-all_stage_devices = [[(0, None, 0)]]
-return distributed_init_method, all_stage_devices
+return distributed_init_method, None
-# Assume we have a uniform cluster that each node has the same number of
-# GPUs for now.
-valid_node_resources = []
-num_devices_per_node = None
-for node in ray.nodes():
-    if (not node['Alive']) or node['Resources']['GPU'] <= 0:
-        continue
-    if num_devices_per_node is None:
-        num_devices_per_node = node['Resources']['GPU']
-    else:
-        assert num_devices_per_node == node['Resources']['GPU'], (
-            "The number of GPUs per node is not uniform.")
-    for key in node['Resources']:
-        if key.startswith('node:'):
-            valid_node_resources.append(key)
-# Verify the parallel config.
-num_nodes = len(valid_node_resources)
-if parallel_config.world_size > num_nodes * num_devices_per_node:
-    raise ValueError(
-        "The number of required GPUs exceeds the total number of "
-        "available GPUs.")
-if parallel_config.tensor_parallel_size >= num_devices_per_node:
-    if parallel_config.tensor_parallel_size % num_devices_per_node != 0:
-        raise ValueError(
-            "The number of tensor parallelism is not divisible by the "
-            "number of GPUs per node.")
-else:
-    if num_devices_per_node % parallel_config.tensor_parallel_size != 0:
-        raise ValueError(
-            "The number of GPUs per node is not divisible by the number "
-            "of tensor parallelism.")
-# Assign GPUs to pipeline stages.
-rank = 0
-current_node_id = 0
-current_device_id = 0
-distributed_init_method = None
-all_stage_devices = []
-for _ in range(parallel_config.pipeline_parallel_size):
-    stage_devices = []
-    for _ in range(parallel_config.tensor_parallel_size):
-        node_resource = valid_node_resources[current_node_id]
-        stage_devices.append((rank, node_resource, current_device_id))
-        if distributed_init_method is None:
-            ip = node_resource.split("node:")[-1]
-            port = random.randint(10000, 20000)
-            distributed_init_method = f"tcp://{ip}:{port}"
-        rank += 1
-        current_device_id += 1
-        if current_device_id >= num_devices_per_node:
-            current_node_id += 1
-            current_device_id = 0
-    all_stage_devices.append(stage_devices)
-return distributed_init_method, all_stage_devices
+current_placement_group = ray.util.get_current_placement_group()
+if current_placement_group:
+    # We are in a placement group
+    bundles = current_placement_group.bundle_specs
+    # Verify that we can use the placement group.
+    gpu_bundles = 0
+    for bundle in bundles:
+        bundle_gpus = bundle.get("GPU", 0)
+        if bundle_gpus > 1:
+            raise ValueError(
+                "Placement group bundle cannot have more than 1 GPU.")
+        if bundle_gpus:
+            gpu_bundles += 1
+    if parallel_config.world_size > gpu_bundles:
+        raise ValueError(
+            "The number of required GPUs exceeds the total number of "
+            "available GPUs in the placement group.")
+else:
+    num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
+    if parallel_config.world_size > num_gpus_in_cluster:
+        raise ValueError(
+            "The number of required GPUs exceeds the total number of "
+            "available GPUs in the cluster.")
+    # Create a new placement group
+    current_placement_group = ray.util.placement_group([{
+        "GPU": 1
+    }] * parallel_config.world_size)
+    # Wait until PG is ready - this will block until all
+    # requested resources are available, and will timeout
+    # if they cannot be provisioned.
+    ray.get(current_placement_group.ready(), timeout=1800)
+return None, current_placement_group
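The new placement-group branch only checks bundle shapes: at most one GPU per bundle, and at least world_size GPU bundles overall. A hedged sketch of that validation with plain dicts standing in for Ray bundle specs (no Ray required):

def validate_placement_bundles(bundle_specs, world_size: int) -> None:
    # Every bundle may contribute at most one GPU; the world size must fit
    # in the number of GPU-carrying bundles.
    gpu_bundles = 0
    for bundle in bundle_specs:
        bundle_gpus = bundle.get("GPU", 0)
        if bundle_gpus > 1:
            raise ValueError("Placement group bundle cannot have more than 1 GPU.")
        if bundle_gpus:
            gpu_bundles += 1
    if world_size > gpu_bundles:
        raise ValueError("The number of required GPUs exceeds the total number "
                         "of available GPUs in the placement group.")

validate_placement_bundles([{"GPU": 1, "CPU": 1}, {"GPU": 1}], world_size=2)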


@ -1,92 +0,0 @@
from typing import List, Tuple, Union
from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
from vllm.logger import init_logger
logger = init_logger(__name__)
_MODEL_TYPES_WITH_SLOW_TOKENIZER = []
def get_tokenizer(
model_name: str,
*args,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface."""
config = AutoConfig.from_pretrained(model_name)
if "open_llama" in model_name:
kwargs["use_fast"] = False
logger.info(
"OpenLLaMA models do not support the fast tokenizer. "
"Using the slow tokenizer instead.")
elif config.model_type == "llama" and getattr(kwargs, "use_fast", True):
# LLaMA fast tokenizer causes protobuf errors in some environments.
# However, we found that the below LLaMA fast tokenizer works well in
# most environments.
model_name = "hf-internal-testing/llama-tokenizer"
logger.info(
f"Using the LLaMA fast tokenizer in '{model_name}' to avoid "
"potential protobuf errors.")
elif config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
if getattr(kwargs, "use_fast", False) == True:
raise ValueError(
f"Cannot use the fast tokenizer for {config.model_type} due to "
"bugs in the fast tokenizer.")
logger.info(
f"Using the slow tokenizer for {config.model_type} due to bugs in "
"the fast tokenizer. This could potentially lead to performance "
"degradation.")
kwargs["use_fast"] = False
return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
def detokenize_incrementally(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
prev_output_tokens: List[str],
new_token_id: int,
skip_special_tokens: bool,
) -> Tuple[str, str]:
"""Detokenizes the new token in conjuction with the previous output tokens.
NOTE: This function does not update prev_output_tokens.
Returns:
new_token: The new token as a string.
output_text: The new output text as a string.
"""
new_token = tokenizer.convert_ids_to_tokens(
new_token_id, skip_special_tokens=skip_special_tokens)
output_tokens = prev_output_tokens + [new_token]
# Convert the tokens to a string.
# Optimization: If the tokenizer does not have `added_tokens_encoder`,
# then we can directly use `convert_tokens_to_string`.
if not getattr(tokenizer, "added_tokens_encoder", {}):
output_text = tokenizer.convert_tokens_to_string(output_tokens)
return new_token, output_text
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
# NOTE(woosuk): The following code is slow because it runs a for loop over
# the output_tokens. In Python, running a for loop over a list can be slow
# even when the loop body is very simple.
sub_texts = []
current_sub_text = []
for token in output_tokens:
if skip_special_tokens and token in tokenizer.all_special_ids:
continue
if token in tokenizer.added_tokens_encoder:
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
current_sub_text = []
sub_texts.append(token)
else:
current_sub_text.append(token)
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
output_text = " ".join(sub_texts)
return new_token, output_text
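The deleted helper's fast path just appends the new token and re-joins. A toy sketch of that path with a stand-in tokenizer (the real code uses a HuggingFace tokenizer and also handles added_tokens_encoder, skipped here):

from typing import List, Tuple

class ToyTokenizer:
    """Hypothetical stand-in for a HuggingFace tokenizer."""
    added_tokens_encoder = {}

    def convert_ids_to_tokens(self, token_id, skip_special_tokens=False):
        return f"<tok{token_id}>"

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return " ".join(tokens)

def detokenize_incrementally(tokenizer, prev_output_tokens, new_token_id,
                             skip_special_tokens) -> Tuple[str, str]:
    new_token = tokenizer.convert_ids_to_tokens(
        new_token_id, skip_special_tokens=skip_special_tokens)
    output_tokens = prev_output_tokens + [new_token]
    # Fast path only: no added_tokens_encoder entries to special-case.
    assert not getattr(tokenizer, "added_tokens_encoder", {})
    return new_token, tokenizer.convert_tokens_to_string(output_tokens)

tok, text = detokenize_incrementally(ToyTokenizer(), ["<tok1>"], 2, False)
assert tok == "<tok2>" and text == "<tok1> <tok2>"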


@ -3,7 +3,7 @@ import json
from typing import AsyncGenerator from typing import AsyncGenerator
from fastapi import BackgroundTasks, FastAPI, Request from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn import uvicorn
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
@ -12,7 +12,7 @@ from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds. TIMEOUT_KEEP_ALIVE = 5 # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
app = FastAPI() app = FastAPI()
@ -37,8 +37,7 @@ async def generate(request: Request) -> Response:
async for request_output in results_generator: async for request_output in results_generator:
prompt = request_output.prompt prompt = request_output.prompt
text_outputs = [ text_outputs = [
prompt + output.text prompt + output.text for output in request_output.outputs
for output in request_output.outputs
] ]
ret = {"text": text_outputs} ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8") yield (json.dumps(ret) + "\0").encode("utf-8")
@ -63,12 +62,9 @@ async def generate(request: Request) -> Response:
assert final_output is not None assert final_output is not None
prompt = final_output.prompt prompt = final_output.prompt
text_outputs = [ text_outputs = [prompt + output.text for output in final_output.outputs]
prompt + output.text
for output in final_output.outputs
]
ret = {"text": text_outputs} ret = {"text": text_outputs}
return Response(content=json.dumps(ret)) return JSONResponse(ret)
if __name__ == "__main__": if __name__ == "__main__":
@ -81,5 +77,8 @@ if __name__ == "__main__":
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args) engine = AsyncLLMEngine.from_engine_args(engine_args)
uvicorn.run(app, host=args.host, port=args.port, log_level="debug", uvicorn.run(app,
host=args.host,
port=args.port,
log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE) timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
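For context, a minimal client sketch for this demo server; the /generate route, the default port, and the payload keys (forwarded to SamplingParams) are assumptions drawn from the surrounding code rather than a documented API.

import requests  # hypothetical client dependency

resp = requests.post(
    "http://localhost:8000/generate",  # assumed route for the handler above
    json={"prompt": "San Francisco is a", "max_tokens": 32, "stream": False},
    timeout=60,
)
print(resp.json()["text"])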

View File

@ -25,6 +25,11 @@ class LLM:
Args: Args:
model: The name or path of a HuggingFace Transformers model. model: The name or path of a HuggingFace Transformers model.
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
if available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
tensor_parallel_size: The number of GPUs to use for distributed tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism. execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently, dtype: The data type for the model weights and activations. Currently,
@ -38,6 +43,9 @@ class LLM:
def __init__( def __init__(
self, self,
model: str, model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
dtype: str = "auto", dtype: str = "auto",
seed: int = 0, seed: int = 0,
@ -47,6 +55,9 @@ class LLM:
kwargs["disable_log_stats"] = True kwargs["disable_log_stats"] = True
engine_args = EngineArgs( engine_args = EngineArgs(
model=model, model=model,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
dtype=dtype, dtype=dtype,
seed=seed, seed=seed,
@ -56,10 +67,15 @@ class LLM:
self.request_counter = Counter() self.request_counter = Counter()
def get_tokenizer( def get_tokenizer(
self, self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer return self.llm_engine.tokenizer
def set_tokenizer(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None:
self.llm_engine.tokenizer = tokenizer
def generate( def generate(
self, self,
prompts: Optional[Union[str, List[str]]] = None, prompts: Optional[Union[str, List[str]]] = None,
@ -139,4 +155,8 @@ class LLM:
pbar.update(1) pbar.update(1)
if use_tqdm: if use_tqdm:
pbar.close() pbar.close()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
outputs = sorted(outputs, key=lambda x: int(x.request_id))
return outputs return outputs
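A short usage sketch of the constructor arguments and tokenizer accessors added in this file; the model name is a placeholder and the top-level imports are assumed from the package's public API.

from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",   # placeholder model
    tokenizer=None,              # default: reuse the model name or path
    tokenizer_mode="auto",       # "slow" always uses the slow tokenizer
    trust_remote_code=False,
    tensor_parallel_size=1,
)

# Accessors added in this diff.
tokenizer = llm.get_tokenizer()
llm.set_tokenizer(tokenizer)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
# Outputs are returned sorted by request ID, so they line up with the prompts.
for out in outputs:
    print(out.outputs[0].text)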

View File

@ -1,30 +1,44 @@
# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py # Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
import argparse import argparse
from http import HTTPStatus import asyncio
import json import json
import time import time
from typing import AsyncGenerator, Dict, List, Optional from http import HTTPStatus
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
import fastapi import fastapi
import uvicorn
from fastapi import BackgroundTasks, Request from fastapi import BackgroundTasks, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
import uvicorn from packaging import version
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.tokenizer_utils import get_tokenizer
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
CompletionRequest, CompletionResponse, CompletionResponseChoice, CompletionRequest, CompletionResponse, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse, CompletionResponseStreamChoice, CompletionStreamResponse,
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo) LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
try:
import fastchat
from fastchat.conversation import Conversation, SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
_fastchat_available = True
except ImportError:
_fastchat_available = False
TIMEOUT_KEEP_ALIVE = 5 # seconds TIMEOUT_KEEP_ALIVE = 5 # seconds
logger = init_logger(__name__) logger = init_logger(__name__)
@ -34,14 +48,13 @@ app = fastapi.FastAPI()
def create_error_response(status_code: HTTPStatus, def create_error_response(status_code: HTTPStatus,
message: str) -> JSONResponse: message: str) -> JSONResponse:
return JSONResponse( return JSONResponse(ErrorResponse(message=message,
ErrorResponse(message=message, type="invalid_request_error").dict(), type="invalid_request_error").dict(),
status_code=status_code.value status_code=status_code.value)
)
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc): async def validation_exception_handler(request, exc): # pylint: disable=unused-argument
return create_error_response(HTTPStatus.BAD_REQUEST, str(exc)) return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
@ -55,11 +68,88 @@ async def check_model(request) -> Optional[JSONResponse]:
return ret return ret
async def get_gen_prompt(request) -> str:
if not _fastchat_available:
raise ModuleNotFoundError(
"fastchat is not installed. Please install fastchat to use "
"the chat completion and conversation APIs: `$ pip install fschat`"
)
if version.parse(fastchat.__version__) < version.parse("0.2.23"):
raise ImportError(
f"fastchat version is low. Current version: {fastchat.__version__} "
"Please upgrade fastchat to use: `$ pip install -U fschat`")
conv = get_conversation_template(request.model)
conv = Conversation(
name=conv.name,
system_template=conv.system_template,
system_message=conv.system_message,
roles=conv.roles,
messages=list(conv.messages), # prevent in-place modification
offset=conv.offset,
sep_style=SeparatorStyle(conv.sep_style),
sep=conv.sep,
sep2=conv.sep2,
stop_str=conv.stop_str,
stop_token_ids=conv.stop_token_ids,
)
if isinstance(request.messages, str):
prompt = request.messages
else:
for message in request.messages:
msg_role = message["role"]
if msg_role == "system":
conv.system_message = message["content"]
elif msg_role == "user":
conv.append_message(conv.roles[0], message["content"])
elif msg_role == "assistant":
conv.append_message(conv.roles[1], message["content"])
else:
raise ValueError(f"Unknown role: {msg_role}")
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
return prompt
async def check_length(
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None
) -> Tuple[List[int], Optional[JSONResponse]]:
assert (not (prompt is None and prompt_ids is None)
and not (prompt is not None and prompt_ids is not None)
), "Either prompt or prompt_ids should be provided."
if prompt_ids is not None:
input_ids = prompt_ids
else:
input_ids = tokenizer(prompt).input_ids
token_num = len(input_ids)
if token_num + request.max_tokens > max_model_len:
return input_ids, create_error_response(
HTTPStatus.BAD_REQUEST,
f"This model's maximum context length is {max_model_len} tokens. "
f"However, you requested {request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.",
)
else:
return input_ids, None
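The arithmetic behind this check is simple: prompt tokens plus the requested completion budget must fit in the model context. A worked example with made-up numbers:

max_model_len = 2048          # illustrative context length
prompt_token_num = 1900       # tokens already in the messages
requested_max_tokens = 256    # completion budget from the request

if prompt_token_num + requested_max_tokens > max_model_len:
    # 1900 + 256 = 2156 > 2048, so the request is rejected with the
    # error message constructed above.
    print("request rejected")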
@app.get("/v1/models") @app.get("/v1/models")
async def show_available_models(): async def show_available_models():
"""Show available models. Right now we only have one model.""" """Show available models. Right now we only have one model."""
model_cards = [ModelCard(id=served_model, root=served_model, model_cards = [
permission=[ModelPermission()])] ModelCard(id=served_model,
root=served_model,
permission=[ModelPermission()])
]
return ModelList(data=model_cards) return ModelList(data=model_cards)
@ -76,15 +166,187 @@ def create_logprobs(token_ids: List[int],
if len(logprobs.text_offset) == 0: if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset) logprobs.text_offset.append(initial_text_offset)
else: else:
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len) logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token) last_token_len = len(token)
logprobs.top_logprobs.append( logprobs.top_logprobs.append({
{tokenizer.convert_ids_to_tokens(i): p tokenizer.convert_ids_to_tokens(i): p
for i, p in id_logprob.items()}) for i, p in id_logprob.items()
})
return logprobs return logprobs
@app.post("/v1/chat/completions")
async def create_chat_completion(raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API.
NOTE: Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (to be supported by vLLM engine)
"""
request = ChatCompletionRequest(**await raw_request.json())
logger.info(f"Received chat completion request: {request}")
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.logit_bias is not None:
# TODO: support logit_bias in vLLM engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
prompt = await get_gen_prompt(request)
token_ids, error_check_ret = await check_length(request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
try:
sampling_params = SamplingParams(
n=request.n,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
temperature=request.temperature,
top_p=request.top_p,
stop=request.stop,
max_tokens=request.max_tokens,
best_of=request.best_of,
top_k=request.top_k,
ignore_eos=request.ignore_eos,
use_beam_search=request.use_beam_search,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json(
index: int,
text: str,
finish_reason: Optional[str] = None,
) -> str:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=text),
finish_reason=finish_reason,
)
response = ChatCompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
response_json = response.json(ensure_ascii=False)
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
# First chunk with role
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role="assistant"),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.json(exclude_unset=True, ensure_ascii=False)
yield f"data: {data}\n\n"
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
response_json = create_stream_response_json(
index=i,
text=delta_text,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
response_json = create_stream_response_json(
index=i,
text="",
finish_reason=output.finish_reason,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if request.stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream",
background=background_tasks)
# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await abort_request()
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
assert final_res is not None
choices = []
for output in final_res.outputs:
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role="assistant", content=output.text),
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
if request.stream:
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
return response
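A minimal client call against this endpoint; the base URL and model name are placeholders, and the fields beyond the OpenAI schema (top_k, use_beam_search) come from the ChatCompletionRequest additions later in this changeset.

import requests  # hypothetical client dependency

payload = {
    "model": "my-served-model",  # must match --served-model-name (or the model path)
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is PagedAttention?"},
    ],
    "max_tokens": 128,
    "temperature": 0.7,
    # vLLM-specific extensions:
    "top_k": 40,
    "use_beam_search": False,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])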
@app.post("/v1/completions") @app.post("/v1/completions")
async def create_completion(raw_request: Request): async def create_completion(raw_request: Request):
"""Completion API similar to OpenAI's API. """Completion API similar to OpenAI's API.
@ -124,7 +386,34 @@ async def create_completion(raw_request: Request):
model_name = request.model model_name = request.model
request_id = f"cmpl-{random_uuid()}" request_id = f"cmpl-{random_uuid()}"
use_token_ids = False
if isinstance(request.prompt, list):
if len(request.prompt) == 0:
return create_error_response(HTTPStatus.BAD_REQUEST,
"please provide at least one prompt")
first_element = request.prompt[0]
if isinstance(first_element, int):
use_token_ids = True
prompt = request.prompt prompt = request.prompt
elif isinstance(first_element, (str, list)):
# TODO: handles multiple prompt case in list[list[int]]
if len(request.prompt) > 1:
return create_error_response(
HTTPStatus.BAD_REQUEST,
"multiple prompts in a batch is not currently supported")
use_token_ids = not isinstance(first_element, str)
prompt = request.prompt[0]
else:
prompt = request.prompt
if use_token_ids:
_, error_check_ret = await check_length(request, prompt_ids=prompt)
else:
token_ids, error_check_ret = await check_length(request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret
created_time = int(time.time()) created_time = int(time.time())
try: try:
sampling_params = SamplingParams( sampling_params = SamplingParams(
@ -144,22 +433,30 @@ async def create_completion(raw_request: Request):
except ValueError as e: except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = engine.generate(prompt, sampling_params, if use_token_ids:
request_id) result_generator = engine.generate(None,
sampling_params,
request_id,
prompt_token_ids=prompt)
else:
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)
# Similar to the OpenAI API, when n != best_of, we do not stream the # Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search. # results. In addition, we do not stream the results when use beam search.
stream = (request.stream and stream = (request.stream
(request.best_of is None or request.n == request.best_of) and and (request.best_of is None or request.n == request.best_of)
not request.use_beam_search) and not request.use_beam_search)
async def abort_request() -> None: async def abort_request() -> None:
await engine.abort(request_id) await engine.abort(request_id)
def create_stream_response_json(index: int, def create_stream_response_json(
index: int,
text: str, text: str,
logprobs: Optional[LogProbs] = None, logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None) -> str: finish_reason: Optional[str] = None,
) -> str:
choice_data = CompletionResponseStreamChoice( choice_data = CompletionResponseStreamChoice(
index=index, index=index,
text=text, text=text,
@ -200,7 +497,8 @@ async def create_completion(raw_request: Request):
) )
yield f"data: {response_json}\n\n" yield f"data: {response_json}\n\n"
if output.finish_reason is not None: if output.finish_reason is not None:
logprobs = LogProbs() if request.logprobs is not None else None logprobs = (LogProbs()
if request.logprobs is not None else None)
response_json = create_stream_response_json( response_json = create_stream_response_json(
index=i, index=i,
text="", text="",
@ -244,8 +542,8 @@ async def create_completion(raw_request: Request):
choices.append(choice_data) choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids) num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids) num_generated_tokens = sum(
for output in final_res.outputs) len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens, completion_tokens=num_generated_tokens,
@ -263,9 +561,11 @@ async def create_completion(raw_request: Request):
# When user requests streaming but we don't stream, we still need to # When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event. # return a streaming response with a single event.
response_json = response.json(ensure_ascii=False) response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]: async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n" yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(), return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream") media_type="text/event-stream")
@ -274,26 +574,34 @@ async def create_completion(raw_request: Request):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server." description="vLLM OpenAI-Compatible RESTful API server.")
) parser.add_argument("--host",
parser.add_argument("--host", type=str, default="localhost", help="host name") type=str,
default="localhost",
help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number") parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument( parser.add_argument("--allow-credentials",
"--allow-credentials", action="store_true", help="allow credentials" action="store_true",
) help="allow credentials")
parser.add_argument( parser.add_argument("--allowed-origins",
"--allowed-origins", type=json.loads, default=["*"], help="allowed origins" type=json.loads,
) default=["*"],
parser.add_argument( help="allowed origins")
"--allowed-methods", type=json.loads, default=["*"], help="allowed methods" parser.add_argument("--allowed-methods",
) type=json.loads,
parser.add_argument( default=["*"],
"--allowed-headers", type=json.loads, default=["*"], help="allowed headers" help="allowed methods")
) parser.add_argument("--allowed-headers",
parser.add_argument("--served-model-name", type=str, default=None, type=json.loads,
help="The model name used in the API. If not specified, " default=["*"],
"the model name will be the same as the " help="allowed headers")
"huggingface name.") parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.")
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
@ -307,13 +615,23 @@ if __name__ == "__main__":
logger.info(f"args: {args}") logger.info(f"args: {args}")
served_model = args.served_model_name or args.model if args.served_model_name is not None:
served_model = args.served_model_name
else:
served_model = args.model
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args) engine = AsyncLLMEngine.from_engine_args(engine_args)
engine_model_config = asyncio.run(engine.get_model_config())
max_model_len = engine_model_config.get_max_model_len()
# A separate tokenizer to map token IDs to strings. # A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(args.model) tokenizer = get_tokenizer(engine_args.tokenizer,
tokenizer_mode=engine_args.tokenizer_mode,
trust_remote_code=engine_args.trust_remote_code)
uvicorn.run(app, host=args.host, port=args.port, log_level="info", uvicorn.run(app,
host=args.host,
port=args.port,
log_level="info",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE) timeout_keep_alive=TIMEOUT_KEEP_ALIVE)

View File

@ -1,4 +1,5 @@
# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py # Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time import time
from typing import Dict, List, Literal, Optional, Union from typing import Dict, List, Literal, Optional, Union
@ -53,21 +54,28 @@ class UsageInfo(BaseModel):
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
model: str model: str
messages: List[Dict[str, str]] messages: Union[str, List[Dict[str, str]]]
temperature: Optional[float] = 0.7 temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0 top_p: Optional[float] = 1.0
n: Optional[int] = 1 n: Optional[int] = 1
max_tokens: Optional[int] = None max_tokens: Optional[int] = 16
stop: Optional[Union[str, List[str]]] = None stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False stream: Optional[bool] = False
presence_penalty: Optional[float] = 0.0 presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None user: Optional[str] = None
# Additional parameters supported by vLLM
best_of: Optional[int] = None
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
class CompletionRequest(BaseModel): class CompletionRequest(BaseModel):
model: str model: str
prompt: str # a string, array of strings, array of tokens, or array of token arrays
prompt: Union[List[int], List[List[int]], str, List[str]]
suffix: Optional[str] = None suffix: Optional[str] = None
max_tokens: Optional[int] = 16 max_tokens: Optional[int] = 16
temperature: Optional[float] = 1.0 temperature: Optional[float] = 1.0
@ -92,7 +100,8 @@ class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list) text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) top_logprobs: List[Optional[Dict[str,
float]]] = Field(default_factory=list)
class CompletionResponseChoice(BaseModel): class CompletionResponseChoice(BaseModel):
@ -124,3 +133,42 @@ class CompletionStreamResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
model: str model: str
choices: List[CompletionResponseStreamChoice] choices: List[CompletionResponseStreamChoice]
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Optional[Literal["stop", "length"]] = None
class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]] = None
class ChatCompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
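The widened CompletionRequest.prompt field accepts text or pre-tokenized input. A sketch of the accepted shapes (token IDs and model name are placeholders); per the completions handler above, only one prompt per request is currently served.

as_text = {"model": "my-served-model", "prompt": "Hello, world", "max_tokens": 8}
as_text_list = {"model": "my-served-model", "prompt": ["Hello, world"], "max_tokens": 8}
as_token_ids = {"model": "my-served-model", "prompt": [15496, 11, 995], "max_tokens": 8}
# List[List[int]] is declared in the schema, but batches larger than one
# are rejected by the handler for now.
as_token_id_batch = {"model": "my-served-model", "prompt": [[15496, 11, 995]], "max_tokens": 8}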

View File

@ -1,9 +1,9 @@
# Adapted from https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py # Adapted from
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration for vLLM."""
import logging import logging
import sys import sys
_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S" _DATE_FORMAT = "%m-%d %H:%M:%S"

View File

@ -2,7 +2,6 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
__all__ = [ __all__ = [
"InputMetadata", "InputMetadata",
"get_model", "get_model",

View File

@ -1,18 +1,29 @@
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import torch import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from xformers.ops import AttentionBias
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
class InputMetadata: class InputMetadata:
"""Metadata for input sequences. Used for PagedAttention.
Args:
seq_groups: List of (seq_ids, sampling_params).
seq_data: Seq_id -> SequenceData.
prompt_lens: Lengths of prompts.
slot_mapping: The address to write the new KV to of each token.
context_lens: the length of attention context for each generation token.
max_context_len: The maximum context length.
block_tables: The block tables. (Seq id -> list of physical block)
"""
def __init__( def __init__(
self, self,
seq_groups: List[Tuple[List[int], SamplingParams]], # List of (seq_ids, sampling_params). seq_groups: List[Tuple[List[int], SamplingParams]],
seq_data: Dict[int, SequenceData], # Seq_id -> SequenceData. seq_data: Dict[int, SequenceData],
prompt_lens: List[int], prompt_lens: List[int],
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
context_lens: torch.Tensor, context_lens: torch.Tensor,
@ -27,7 +38,6 @@ class InputMetadata:
self.max_context_len = max_context_len self.max_context_len = max_context_len
self.block_tables = block_tables self.block_tables = block_tables
self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
self.num_prompts = len(prompt_lens) self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens) self.num_prompt_tokens = sum(prompt_lens)
self.num_generation_tokens = context_lens.shape[0] self.num_generation_tokens = context_lens.shape[0]
@ -39,6 +49,9 @@ class InputMetadata:
assert block_tables.shape[0] == self.num_generation_tokens assert block_tables.shape[0] == self.num_generation_tokens
assert context_lens.shape[0] == self.num_generation_tokens assert context_lens.shape[0] == self.num_generation_tokens
# Set during the execution of the first attention op.
self.attn_bias: List[AttentionBias] = []
def __repr__(self) -> str: def __repr__(self) -> str:
# Print only useful metadata. # Print only useful metadata.
return (f'InputMetadata(' return (f'InputMetadata('

View File

@ -4,10 +4,50 @@ import torch.nn as nn
from vllm import activation_ops from vllm import activation_ops
class SiluAndMul(nn.Module):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
Shapes:
x: (num_tokens, 2 * d)
return: (num_tokens, d)
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
num_tokens = x.shape[0]
d = x.shape[1] // 2
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
activation_ops.silu_and_mul(out, x)
return out
class NewGELU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
num_tokens = x.shape[0]
d = x.shape[1]
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
activation_ops.gelu_new(out, x)
return out
class FastGELU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
num_tokens = x.shape[0]
d = x.shape[1]
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
activation_ops.gelu_fast(out, x)
return out
_ACTIVATION_REGISTRY = { _ACTIVATION_REGISTRY = {
"gelu": nn.GELU(), "gelu": nn.GELU(),
"gelu_new": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors. "gelu_fast": FastGELU(),
"gelu_fast": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors. "gelu_new": NewGELU(),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(), "relu": nn.ReLU(),
} }
@ -18,23 +58,3 @@ def get_act_fn(act_fn: str) -> nn.Module:
if act_fn in _ACTIVATION_REGISTRY: if act_fn in _ACTIVATION_REGISTRY:
return _ACTIVATION_REGISTRY[act_fn] return _ACTIVATION_REGISTRY[act_fn]
raise ValueError(f"Activation function {act_fn!r} is not supported.") raise ValueError(f"Activation function {act_fn!r} is not supported.")
class SiluAndMul(nn.Module):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
"""
def __init__(self):
super().__init__()
def forward(
self,
x: torch.Tensor, # (num_tokens, 2 * d)
) -> torch.Tensor: # (num_tokens, d)
num_tokens = x.shape[0]
d = x.shape[1] // 2
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
activation_ops.silu_and_mul(out, x)
return out
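For readers without the CUDA extension, a pure-PyTorch reference of what these kernels compute, matching the SiluAndMul docstring (x -> silu(x[:d]) * x[d:]) and the tanh-approximate GELU; this is a sketch for clarity, not the fused implementation, and the two GELU variants differ only by small rounding details.

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x: (num_tokens, 2 * d) -> (num_tokens, d)
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def gelu_tanh_ref(x: torch.Tensor) -> torch.Tensor:
    # Tanh approximation underlying the gelu_new / gelu_fast kernels.
    return F.gelu(x, approximate="tanh")

x = torch.randn(4, 16)
print(silu_and_mul_ref(x).shape)  # torch.Size([4, 8])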

View File

@ -1,28 +1,39 @@
"""Multi-head attention.""" """Multi-head attention."""
from typing import Optional from typing import List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias)
from vllm import attention_ops from vllm import attention_ops
from vllm import cache_ops from vllm import cache_ops
from vllm import pos_encoding_ops from vllm import pos_encoding_ops
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128] _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
class PagedAttention(nn.Module): class PagedAttention(nn.Module):
# pylint: disable=line-too-long
"""GPT-style multi-head PagedAttention. """GPT-style multi-head PagedAttention.
This class takes flattened 1D query, key, and value tensors as input. The This class takes flattened 1D query, key, and value tensors as input. The
input 1D tensors can be split into three parts: the prompt tokens, the input 1D tensors can either contain prompt tokens or generation tokens, in
generation tokens, and the paddings. addition to paddings.
|<------------------------------------- num_valid_tokens ------------------------------------->| If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->| |<---------------------- num_valid_tokens ---------------------->|
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|
Otherwise, the layout is as follows:
|<------------------ num_valid_tokens ------------------->|
|<------- num_generation_tokens (M) ------->|
|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
The prompts might have different lengths, while the generation tokens always The prompts might have different lengths, while the generation tokens always
have length 1. The paddings are appended to make the input length a multiple have length 1. The paddings are appended to make the input length a multiple
@ -41,31 +52,67 @@ class PagedAttention(nn.Module):
5. Output a flattened 1D tensor. 5. Output a flattened 1D tensor.
""" """
def __init__(self, num_heads: int, head_size: int, scale: float) -> None: def __init__(self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None) -> None:
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
self.attn_op = xops.fmha.cutlass.FwOp() self.attn_op = xops.fmha.cutlass.FwOp()
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.head_mapping = torch.repeat_interleave(
torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
self.num_queries_per_kv)
if self.head_size not in _SUPPORTED_HEAD_SIZES: if self.head_size not in _SUPPORTED_HEAD_SIZES:
raise ValueError(f"head_size ({self.head_size}) is not supported. " raise ValueError(f"head_size ({self.head_size}) is not supported. "
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.") f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
if input_metadata.attn_bias:
# Already set by a previous layer.
return
prompt_lens = input_metadata.prompt_lens
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
input_metadata.attn_bias.append(attn_bias)
def multi_query_kv_attention( def multi_query_kv_attention(
self, self,
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] output: torch.Tensor,
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] query: torch.Tensor,
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] key: torch.Tensor,
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] value: torch.Tensor,
attn_bias: xops.AttentionBias, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
"""Normal attention for the prompt tokens.
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
input_metadata: metadata for paged attention.
"""
if self.num_kv_heads != self.num_heads:
# Project the key and value tensors to the desired number of heads.
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value,
self.num_queries_per_kv,
dim=1)
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize. # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out = xops.memory_efficient_attention_forward( out = xops.memory_efficient_attention_forward(
query.unsqueeze(0), query.unsqueeze(0),
key.unsqueeze(0), key.unsqueeze(0),
value.unsqueeze(0), value.unsqueeze(0),
attn_bias=attn_bias, attn_bias=input_metadata.attn_bias[0],
p=0.0, p=0.0,
scale=self.scale, scale=self.scale,
op=self.attn_op, op=self.attn_op,
@ -76,42 +123,72 @@ class PagedAttention(nn.Module):
def single_query_cached_kv_attention( def single_query_cached_kv_attention(
self, self,
output: torch.Tensor, # [num_generation_tokens, num_heads, head_size] output: torch.Tensor,
query: torch.Tensor, # [num_generation_tokens, num_heads, head_size] query: torch.Tensor,
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] key_cache: torch.Tensor,
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] value_cache: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> None: ) -> None:
"""PagedAttention for the generation tokens.
Args:
output: shape = [num_generation_tokens, num_heads, head_size]
query: shape = [num_generation_tokens, num_heads, head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for paged attention.
"""
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention( attention_ops.single_query_cached_kv_attention(
output, output,
query, query,
key_cache, key_cache,
value_cache, value_cache,
self.head_mapping,
self.scale, self.scale,
input_metadata.block_tables, input_metadata.block_tables,
input_metadata.context_lens, input_metadata.context_lens,
block_size, block_size,
input_metadata.max_context_len, input_metadata.max_context_len,
None, # alibi_slopes
) )
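The repeat_interleave projection above is how grouped or multi-query KV heads are expanded to match the query heads for the prompt-time xformers call, while head_mapping gives the decode kernel the same head assignment. A small standalone sketch (shapes are illustrative):

import torch

num_heads, num_kv_heads, head_size, num_tokens = 8, 2, 4, 3
num_queries_per_kv = num_heads // num_kv_heads

key = torch.randn(num_tokens, num_kv_heads, head_size)
# (num_tokens, num_kv_heads, head_size) -> (num_tokens, num_heads, head_size):
# each KV head is repeated so every group of query heads gets its own copy.
key_expanded = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
print(key_expanded.shape)  # torch.Size([3, 8, 4])

# The same grouping, expressed as a per-query-head index of KV heads.
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_queries_per_kv)
print(head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)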
def forward( def forward(
self, self,
query: torch.Tensor, # [num_tokens, num_heads * head_size] query: torch.Tensor,
key: torch.Tensor, # [num_tokens, num_heads * head_size] key: torch.Tensor,
value: torch.Tensor, # [num_tokens, num_heads * head_size] value: torch.Tensor,
key_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size/x, block_size, x] key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size, block_size] value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event], cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size] ) -> torch.Tensor:
# NOTE: The query, key, and value tensors must be sliced from a qkv """PagedAttention forward pass.
# tensor of shape [num_tokens, 3 * num_heads * head_size].
NOTE: The query, key, and value tensors must be sliced from a qkv
tensor of shape [num_tokens, 3 * num_heads * head_size].
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for paged attention.
cache_event: event to wait for the cache operations to finish.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size)
# Pre-allocate the output tensor. # Pre-allocate the output tensor.
output = torch.empty_like(query) output = torch.empty_like(query)
@ -119,12 +196,15 @@ class PagedAttention(nn.Module):
# Compute the attention op for prompts. # Compute the attention op for prompts.
num_prompt_tokens = input_metadata.num_prompt_tokens num_prompt_tokens = input_metadata.num_prompt_tokens
if num_prompt_tokens > 0: if num_prompt_tokens > 0:
# Prompt run.
assert input_metadata.num_generation_tokens == 0
self.set_attn_bias(input_metadata)
self.multi_query_kv_attention( self.multi_query_kv_attention(
output[:num_prompt_tokens], output[:num_prompt_tokens],
query[:num_prompt_tokens], query[:num_prompt_tokens],
key[:num_prompt_tokens], key[:num_prompt_tokens],
value[:num_prompt_tokens], value[:num_prompt_tokens],
input_metadata.attn_bias, input_metadata,
) )
# Wait until the cache op is done. # Wait until the cache op is done.
@ -147,17 +227,16 @@ class PagedAttention(nn.Module):
) )
if input_metadata.num_generation_tokens > 0: if input_metadata.num_generation_tokens > 0:
# Decoding run.
assert input_metadata.num_prompt_tokens == 0
assert key_cache is not None and value_cache is not None, ( assert key_cache is not None and value_cache is not None, (
"key_cache and value_cache must be provided when " "key_cache and value_cache must be provided when "
"generating tokens." "generating tokens.")
)
# Compute the attention op for generation tokens. # Compute the attention op for generation tokens.
self.single_query_cached_kv_attention( self.single_query_cached_kv_attention(
output[num_prompt_tokens:num_valid_tokens], output[num_prompt_tokens:num_valid_tokens],
query[num_prompt_tokens:num_valid_tokens], query[num_prompt_tokens:num_valid_tokens], key_cache,
key_cache, value_cache, input_metadata)
value_cache,
input_metadata)
# Reshape the output tensor. # Reshape the output tensor.
# NOTE(woosuk): The output tensor may include paddings. # NOTE(woosuk): The output tensor may include paddings.
@ -175,19 +254,21 @@ class PagedAttentionWithRoPE(PagedAttention):
rotary_dim: int, rotary_dim: int,
max_position: int = 8192, max_position: int = 8192,
base: int = 10000, base: int = 10000,
num_kv_heads: Optional[int] = None,
) -> None: ) -> None:
super().__init__(num_heads, head_size, scale) super().__init__(num_heads, head_size, scale, num_kv_heads)
# Create the cos and sin cache. # Create the cos and sin cache.
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim)) inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float() t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float()) freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
cos = freqs.cos() cos = freqs.cos()
sin = freqs.sin() sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1) cache = torch.cat((cos, sin), dim=-1)
# FIXME(woosuk): This assumes that we configure the default dtype when # FIXME(woosuk): This assumes that we configure the default dtype when
# initializing the model. Make it more robust. # initializing the model.
# TODO(woosuk): Make it more robust.
torch_dtype = torch.get_default_dtype() torch_dtype = torch.get_default_dtype()
cache = cache.to(torch_dtype) cache = cache.to(torch_dtype)
# Embedding size: [max_position, rotary_dim] # Embedding size: [max_position, rotary_dim]
@ -195,15 +276,33 @@ class PagedAttentionWithRoPE(PagedAttention):
def forward( def forward(
self, self,
positions: torch.Tensor, # [num_tokens] positions: torch.Tensor,
query: torch.Tensor, # [num_tokens, num_heads * head_size] query: torch.Tensor,
key: torch.Tensor, # [num_tokens, num_heads * head_size] key: torch.Tensor,
value: torch.Tensor, # [num_tokens, num_heads * head_size] value: torch.Tensor,
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] key_cache: torch.Tensor,
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] value_cache: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event], cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size] ) -> torch.Tensor:
""" PagedAttention forward pass with rotary embedding.
Args:
positions: shape = [num_tokens]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for paged attention.
cache_event: event to wait for the cache operations to finish.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Apply rotary embedding to the query and key before passing them # Apply rotary embedding to the query and key before passing them
# to the attention op. # to the attention op.
pos_encoding_ops.rotary_embedding_neox( pos_encoding_ops.rotary_embedding_neox(
@ -222,3 +321,126 @@ class PagedAttentionWithRoPE(PagedAttention):
input_metadata, input_metadata,
cache_event, cache_event,
) )
class PagedAttentionWithALiBi(PagedAttention):
"""PagedAttention with ALiBi attention bias."""
def __init__(self,
num_heads: int,
head_size: int,
scale: float,
slopes: List[float],
num_kv_heads: Optional[int] = None) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads)
assert len(slopes) == num_heads
slopes = torch.tensor(slopes, dtype=torch.float32)
self.register_buffer("alibi_slopes", slopes, persistent=False)
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
if input_metadata.attn_bias:
# Already set by a previous layer.
return
# Generates ALiBi mask for each prompt.
for prompt_len in input_metadata.prompt_lens:
bias = torch.arange(prompt_len)
# Note(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
bias = bias.to(self.alibi_slopes.device)
# When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8.
padded_len = (prompt_len + 7) // 8 * 8
bias = torch.empty(
1, # batch_size
self.num_heads,
prompt_len,
padded_len,
device=self.alibi_slopes.device,
)[:, :, :, :prompt_len].copy_(bias)
bias.mul_(self.alibi_slopes[:, None, None])
attn_bias = LowerTriangularMaskWithTensorBias(bias)
input_metadata.attn_bias.append(attn_bias)
def multi_query_kv_attention(
self,
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
"""Attention with ALiBi bias for the prompt tokens.
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
input_metadata: metadata for paged attention.
"""
if self.num_kv_heads != self.num_heads:
# Project the key and value tensors to the desired number of heads.
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value,
self.num_queries_per_kv,
dim=1)
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
start = 0
for i, prompt_len in enumerate(input_metadata.prompt_lens):
end = start + prompt_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
value[None, start:end],
attn_bias=input_metadata.attn_bias[i],
p=0.0,
scale=self.scale,
op=self.attn_op,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.squeeze(0))
start += prompt_len
return output
def single_query_cached_kv_attention(
self,
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
) -> None:
"""PagedAttention with ALiBi bias for the generation tokens.
Args:
output: shape = [num_generation_tokens, num_heads, head_size]
query: shape = [num_generation_tokens, num_heads, head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for paged attention.
"""
block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention(
output,
query,
key_cache,
value_cache,
self.head_mapping,
self.scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
self.alibi_slopes,
)
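A standalone sketch of the ALiBi bias built in set_attn_bias above, leaving out the multiple-of-8 padding that xformers requires; the slopes and prompt length are placeholders.

import torch

prompt_len = 5
alibi_slopes = torch.tensor([0.5, 0.25])  # placeholder slopes, one per head

pos = torch.arange(prompt_len)
# Relative-position matrix: bias[i, j] = j - i (0 on the diagonal,
# negative for earlier positions), as in the code above.
bias = pos[None, :] - pos[:, None]
# Scale per head -> (num_heads, prompt_len, prompt_len).
bias = bias[None, :, :] * alibi_slopes[:, None, None]
# The causal (lower-triangular) mask is applied on top of this bias by
# LowerTriangularMaskWithTensorBias in the real code.
print(bias.shape)  # torch.Size([2, 5, 5])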

View File

@ -11,6 +11,8 @@ from vllm.model_executor.parallel_utils.tensor_parallel import (
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceOutputs from vllm.sequence import SequenceOutputs
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module): class Sampler(nn.Module):
"""Samples the next tokens from the model's outputs. """Samples the next tokens from the model's outputs.
@ -36,12 +38,15 @@ class Sampler(nn.Module):
embedding: torch.Tensor, embedding: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> Dict[int, SequenceOutputs]: ) -> Dict[int, SequenceOutputs]:
# Get the hidden states that we use for sampling. # Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, input_metadata) hidden_states = _prune_hidden_states(hidden_states, input_metadata)
# Get the logits for the next tokens. # Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t()) logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = gather_from_tensor_model_parallel_region(logits) logits = gather_from_tensor_model_parallel_region(logits)
# Remove paddings in vocab (if any). # Remove paddings in vocab (if any).
logits = logits[:, :self.vocab_size] logits = logits[:, :self.vocab_size]
@ -49,34 +54,37 @@ class Sampler(nn.Module):
# Apply presence and frequency penalties. # Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata) output_tokens = _get_output_tokens(input_metadata)
assert len(output_tokens) == logits.shape[0] assert len(output_tokens) == logits.shape[0]
presence_penalties, frequency_penalties = _get_penalties(input_metadata) presence_penalties, frequency_penalties = _get_penalties(
input_metadata)
assert len(presence_penalties) == logits.shape[0] assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0] assert len(frequency_penalties) == logits.shape[0]
logits = _apply_penalties( logits = _apply_penalties(logits, output_tokens, presence_penalties,
logits, output_tokens, presence_penalties, frequency_penalties, frequency_penalties, self.vocab_size)
self.vocab_size)
# Apply temperature scaling. # Apply temperature scaling.
temperatures = _get_temperatures(input_metadata) temperatures = _get_temperatures(input_metadata)
assert len(temperatures) == logits.shape[0] assert len(temperatures) == logits.shape[0]
if any(t != 1.0 for t in temperatures): if any(t != 1.0 for t in temperatures):
t = torch.tensor( t = torch.tensor(temperatures,
temperatures, dtype=logits.dtype, device=logits.device) dtype=logits.dtype,
device=logits.device)
# Use in-place division to avoid creating a new tensor. # Use in-place division to avoid creating a new tensor.
logits.div_(t.unsqueeze(dim=1)) logits.div_(t.unsqueeze(dim=1))
# Apply top-p and top-k truncation.
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
assert len(top_ps) == len(top_ks) == logits.shape[0]
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
do_top_k = any(k != self.vocab_size for k in top_ks)
if do_top_p or do_top_k:
logits = _apply_top_p_top_k(logits, top_ps, top_ks)
# We use float32 for probabilities and log probabilities. # We use float32 for probabilities and log probabilities.
# Compute the probabilities. # Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float) probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities (before applying top-p and top-k). # Compute the log probabilities (before applying top-p and top-k).
logprobs = torch.log(probs) logprobs = torch.log(probs)
# Apply top-p and top-k truncation.
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
assert len(top_ps) == len(top_ks) == probs.shape[0]
if any(p < 1.0 for p in top_ps) or any(k != self.vocab_size for k in top_ks):
probs = _apply_top_p_top_k(probs, top_ps, top_ks)
# Sample the next tokens. # Sample the next tokens.
return _sample(probs, logprobs, input_metadata) return _sample(probs, logprobs, input_metadata)
@ -96,8 +104,7 @@ def _prune_hidden_states(
def _get_penalties( def _get_penalties(
input_metadata: InputMetadata, input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
) -> Tuple[List[float], List[float]]:
# Collect the presence and frequency penalties. # Collect the presence and frequency penalties.
presence_penalties: List[float] = [] presence_penalties: List[float] = []
frequency_penalties: List[float] = [] frequency_penalties: List[float] = []
@ -116,9 +123,7 @@ def _get_penalties(
return presence_penalties, frequency_penalties return presence_penalties, frequency_penalties
def _get_output_tokens( def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
input_metadata: InputMetadata,
) -> List[List[int]]:
output_tokens: List[List[int]] = [] output_tokens: List[List[int]] = []
for i, seq_group in enumerate(input_metadata.seq_groups): for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, _ = seq_group seq_ids, _ = seq_group
@ -152,7 +157,7 @@ def _apply_penalties(
continue continue
p = presence_penalties[i] p = presence_penalties[i]
f = frequency_penalties[i] f = frequency_penalties[i]
if p == 0.0 and f == 0.0: if p < _SAMPLING_EPS and f < _SAMPLING_EPS:
continue continue
indices.append(i) indices.append(i)
@@ -168,11 +173,13 @@ def _apply_penalties(
                                    device=logits.device)
     frequency_penalties = [frequency_penalties[i] for i in indices]
-    frequency_penalties = torch.tensor(
-        frequency_penalties, dtype=logits.dtype, device=logits.device)
+    frequency_penalties = torch.tensor(frequency_penalties,
+                                       dtype=logits.dtype,
+                                       device=logits.device)
     presence_penalties = [presence_penalties[i] for i in indices]
-    presence_penalties = torch.tensor(
-        presence_penalties, dtype=logits.dtype, device=logits.device)
+    presence_penalties = torch.tensor(presence_penalties,
+                                      dtype=logits.dtype,
+                                      device=logits.device)
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
@@ -182,15 +189,13 @@ def _apply_penalties(
     return logits
-def _get_temperatures(
-    input_metadata: InputMetadata,
-) -> List[float]:
+def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
     # Collect the temperatures for the logits.
     temperatures: List[float] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         temperature = sampling_params.temperature
-        if temperature == 0.0:
+        if temperature < _SAMPLING_EPS:
             # NOTE: Zero temperature means deterministic sampling
             # (i.e., greedy sampling or beam search).
             # Set the temperature to 1 to avoid division by zero.
@@ -230,30 +235,32 @@ def _get_top_p_top_k(
 def _apply_top_p_top_k(
-    probs: torch.Tensor,
+    logits: torch.Tensor,
     top_ps: List[float],
     top_ks: List[int],
 ) -> torch.Tensor:
-    p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
-    k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
+    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
     # Apply top-p.
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort = logits_sort.softmax(dim=-1)
+    probs_sum = probs_sort.cumsum(dim=-1)
     top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
-    probs_sort[top_p_mask] = 0.0
+    logits_sort[top_p_mask] = -float("inf")
     # Apply top-k.
     # Create a mask for the top-k elements.
-    top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
-    top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
+    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
+    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
     top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
-    probs_sort[top_k_mask] = 0.0
+    logits_sort[top_k_mask] = -float("inf")
     # Re-sort the probabilities.
-    probs = torch.gather(
-        probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
-    return probs
+    logits = torch.gather(logits_sort,
+                          dim=-1,
+                          index=torch.argsort(logits_idx, dim=-1))
+    return logits
 def _get_topk_logprobs(
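A toy walk-through of the same masking idea on a single row of logits, useful for seeing what the `(probs_sum - probs_sort) > p` trick does (it always keeps at least the most likely token). The numbers are made up.

import torch

logits = torch.tensor([[3.0, 1.0, 0.5, 0.0]])
p = torch.tensor([0.8])   # top-p
k = torch.tensor([3])     # top-k

logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
probs_sort = logits_sort.softmax(dim=-1)          # ~[0.79, 0.11, 0.06, 0.04]
probs_sum = probs_sort.cumsum(dim=-1)

# Drop tokens whose *preceding* cumulative mass already exceeds p.
logits_sort[(probs_sum - probs_sort) > p.unsqueeze(dim=1)] = -float("inf")
# Drop tokens ranked at or beyond k.
ranks = torch.arange(logits_sort.shape[-1]).expand(logits_sort.shape[0], -1)
logits_sort[ranks >= k.unsqueeze(dim=1)] = -float("inf")

# Scatter the filtered logits back to the original token order.
filtered = torch.gather(logits_sort, dim=-1,
                        index=torch.argsort(logits_idx, dim=-1))
print(filtered)   # tensor([[3., 1., -inf, -inf]])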
@@ -286,7 +293,7 @@ def _sample_from_prompt(
         beam_width = sampling_params.best_of
         _, next_token_ids = torch.topk(prob, beam_width)
         next_token_ids = next_token_ids.tolist()
-    elif sampling_params.temperature == 0.0:
+    elif sampling_params.temperature < _SAMPLING_EPS:
         # Greedy sampling.
         assert sampling_params.best_of == 1
         next_token_id = torch.argmax(prob)
@@ -295,8 +302,9 @@ def _sample_from_prompt(
         # Random sampling.
         # Sample `best_of` tokens for the prompt.
         num_seqs = sampling_params.best_of
-        next_token_ids = torch.multinomial(
-            prob, num_samples=num_seqs, replacement=True)
+        next_token_ids = torch.multinomial(prob,
+                                           num_samples=num_seqs,
+                                           replacement=True)
         next_token_ids = next_token_ids.tolist()
     return next_token_ids
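For reference, `torch.multinomial` with `replacement=True` (as called above) may return the same token id more than once, which is the intended behaviour when drawing `best_of` samples from a single prompt distribution; the probabilities below are made up.

import torch

prob = torch.tensor([0.7, 0.2, 0.1])
next_token_ids = torch.multinomial(prob, num_samples=4, replacement=True)
print(next_token_ids.tolist())   # e.g. [0, 0, 1, 0] -- duplicates are allowed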
@@ -314,8 +322,9 @@ def _sample_from_generation_tokens(
     if sampling_params.use_beam_search:
         # Beam search.
         # Add cumulative logprobs for the sequences in the group.
-        seq_logprobs = torch.tensor(
-            seq_logprobs, dtype=torch.float, device=logprobs.device)
+        seq_logprobs = torch.tensor(seq_logprobs,
+                                    dtype=torch.float,
+                                    device=logprobs.device)
         logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
         vocab_size = logprobs.size(-1)
@@ -343,7 +352,7 @@ def _sample_from_generation_tokens(
         parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
         next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
-    elif sampling_params.temperature == 0.0:
+    elif sampling_params.temperature < _SAMPLING_EPS:
         # Greedy sampling.
         assert len(seq_ids) == 1
         next_token_id = torch.argmax(probs, dim=-1)
@@ -352,8 +361,9 @@ def _sample_from_generation_tokens(
     else:
         # Random sampling.
         # Sample 1 token for each sequence in the group.
-        next_token_ids = torch.multinomial(
-            probs, num_samples=1, replacement=True)
+        next_token_ids = torch.multinomial(probs,
+                                           num_samples=1,
+                                           replacement=True)
         next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
         parent_seq_ids = seq_ids
     return parent_seq_ids, next_token_ids
@@ -380,15 +390,16 @@ def _sample(
             # Sample the next tokens.
             next_token_ids = _sample_from_prompt(prob, sampling_params)
             # Get top-k log probabilities for the next tokens.
-            next_logprobs = _get_topk_logprobs(
-                logprob, sampling_params.logprobs)
+            next_logprobs = _get_topk_logprobs(logprob,
+                                               sampling_params.logprobs)
             # Build the output.
             for seq_id, next_token_id in zip(seq_ids, next_token_ids):
                 output_logprobs = next_logprobs.copy()
                 output_logprobs[next_token_id] = logprob[next_token_id].item()
-                seq_outputs[seq_id] = SequenceOutputs(
-                    seq_id, seq_id, next_token_id, output_logprobs)
+                seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
+                                                      next_token_id,
+                                                      output_logprobs)
         else:
             # Generate the next tokens for generation tokens.
             prob = probs[idx:idx + len(seq_ids)]
@@ -398,22 +409,24 @@ def _sample(
             # Sample the next tokens.
             seq_logprobs = [
                 input_metadata.seq_data[seq_id].cumulative_logprob
-                for seq_id in seq_ids]
+                for seq_id in seq_ids
+            ]
             parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                 seq_ids, prob, logprob, seq_logprobs, sampling_params)
             # Get top-k log probabilities for the next tokens.
             next_logprobs: Dict[int, Dict[int, float]] = {}
-            for i, seq_id in enumerate(seq_ids):
+            for j, seq_id in enumerate(seq_ids):
                 next_logprobs[seq_id] = _get_topk_logprobs(
-                    logprob[i], sampling_params.logprobs)
+                    logprob[j], sampling_params.logprobs)
             # Build the output.
             for seq_id, parent_seq_id, next_token_id in zip(
                     seq_ids, parent_seq_ids, next_token_ids):
-                i = seq_ids.index(parent_seq_id)
+                j = seq_ids.index(parent_seq_id)
                 output_logprobs = next_logprobs[parent_seq_id].copy()
-                output_logprobs[next_token_id] = logprob[i, next_token_id].item()
+                output_logprobs[next_token_id] = logprob[j,
+                                                         next_token_id].item()
                 seq_outputs[seq_id] = SequenceOutputs(
                     seq_id,
                     parent_seq_id,

@@ -6,16 +6,27 @@ import torch.nn as nn
 from transformers import PretrainedConfig
 from vllm.config import ModelConfig
-from vllm.model_executor.models import (GPT2LMHeadModel, GPTNeoXForCausalLM,
-                                        LlamaForCausalLM, OPTForCausalLM)
+from vllm.model_executor.models import *  # pylint: disable=wildcard-import
 from vllm.model_executor.weight_utils import initialize_dummy_weights
 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
+    "AquilaModel": AquilaForCausalLM,
+    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
+    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
+    "BloomForCausalLM": BloomForCausalLM,
+    "FalconForCausalLM": FalconForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
+    "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
+    "GPTJForCausalLM": GPTJForCausalLM,
     "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
+    "InternLMForCausalLM": InternLMForCausalLM,
     "LlamaForCausalLM": LlamaForCausalLM,
+    "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
+    "MPTForCausalLM": MPTForCausalLM,
     "OPTForCausalLM": OPTForCausalLM,
+    "QWenLMHeadModel": QWenLMHeadModel,
+    "RWForCausalLM": FalconForCausalLM,
 }
@@ -26,8 +37,7 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
             return _MODEL_REGISTRY[arch]
     raise ValueError(
         f"Model architectures {architectures} are not supported for now. "
-        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
-    )
+        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
 def get_model(model_config: ModelConfig) -> nn.Module:
@@ -44,8 +54,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
        initialize_dummy_weights(model)
    else:
        # Load the weights from the cached or downloaded files.
-        model.load_weights(
-            model_config.model, model_config.download_dir,
-            model_config.use_np_weights)
+        model.load_weights(model_config.model, model_config.download_dir,
+                           model_config.use_np_weights)
        model = model.cuda()
    return model.eval()
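A self-contained sketch of how the registry above is consulted: the `architectures` list from the Hugging Face config picks the vLLM class, and unknown names raise the error shown in the earlier hunk. The class objects are replaced with dotted-name strings here so the snippet runs without vLLM installed; `_TOY_REGISTRY` and `resolve` are made-up names.

from typing import List

# Stand-in for _MODEL_REGISTRY; values are just dotted names instead of classes.
_TOY_REGISTRY = {
    "LlamaForCausalLM": "vllm.model_executor.models.llama.LlamaForCausalLM",
    "LLaMAForCausalLM": "vllm.model_executor.models.llama.LlamaForCausalLM",
    "RWForCausalLM": "vllm.model_executor.models.falcon.FalconForCausalLM",
}

def resolve(architectures: List[str]) -> str:
    for arch in architectures:
        if arch in _TOY_REGISTRY:
            return _TOY_REGISTRY[arch]
    raise ValueError(f"Model architectures {architectures} are not supported "
                     f"for now. Supported architectures: {list(_TOY_REGISTRY)}")

print(resolve(["RWForCausalLM"]))  # legacy Falcon configs map onto FalconForCausalLM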

@@ -1,12 +1,31 @@
-from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
+from vllm.model_executor.models.aquila import AquilaForCausalLM
+from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM,
+                                                 BaichuanForCausalLM)
+from vllm.model_executor.models.bloom import BloomForCausalLM
+from vllm.model_executor.models.falcon import FalconForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
+from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
+from vllm.model_executor.models.gpt_j import GPTJForCausalLM
+from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
+from vllm.model_executor.models.internlm import InternLMForCausalLM
 from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.model_executor.models.mpt import MPTForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM
+from vllm.model_executor.models.qwen import QWenLMHeadModel
 __all__ = [
+    "AquilaForCausalLM",
+    "BaiChuanForCausalLM",
+    "BaichuanForCausalLM",
+    "BloomForCausalLM",
+    "FalconForCausalLM",
     "GPT2LMHeadModel",
+    "GPTBigCodeForCausalLM",
+    "GPTJForCausalLM",
     "GPTNeoXForCausalLM",
+    "InternLMForCausalLM",
     "LlamaForCausalLM",
+    "MPTForCausalLM",
     "OPTForCausalLM",
+    "QWenLMHeadModel",
 ]

@@ -0,0 +1,362 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
from vllm.transformers_utils.configs.aquila import AquilaConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
class AquilaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class AquilaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
AquilaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1,
keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance +
self.variance_epsilon)
return (self.weight * hidden_states).to(input_dtype)
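# --- Illustrative reference (editor's sketch, not part of the vLLM diff) ----
# AquilaRMSNorm above normalizes by the root mean square of the features:
#     y = weight * x / sqrt(mean(x ** 2) + eps)
# with no mean subtraction and no bias (T5-style LayerNorm). A plain-PyTorch
# restatement of forward() for comparison; the helper name is made up.
def _rms_norm_reference(x: torch.Tensor,
                        weight: torch.Tensor,
                        eps: float = 1e-6) -> torch.Tensor:
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    return (weight * x * torch.rsqrt(variance + eps)).to(x.dtype)
# ----------------------------------------------------------------------------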
class AquilaAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
):
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(
hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
class AquilaDecoderLayer(nn.Module):
def __init__(self, config: AquilaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = AquilaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_attention_heads,
)
self.mlp = AquilaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = AquilaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = AquilaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class AquilaModel(nn.Module):
def __init__(self, config: AquilaConfig):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
#vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([
AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class AquilaForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.model = AquilaModel(config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = [
"embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
"gate_proj.weight", "up_proj.weight"
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
kv_proj_shard_size = (self.config.hidden_size //
self.config.num_attention_heads *
self.config.num_attention_heads // tp_size)
attention_weight_specs = [
# (weight_name, shard_size, offset)
("q_proj", q_proj_shard_size, 0),
("k_proj", kv_proj_shard_size, q_proj_shard_size),
("v_proj", kv_proj_shard_size,
q_proj_shard_size + kv_proj_shard_size),
]
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "rotary_emb.inv_freq" in name:
continue
if "embed_tokens" in name or "lm_head" in name:
param = state_dict[name]
# Consider padding in the vocab size.
padded_vocab_size = (param.shape[0] * tp_size)
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs:
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "qkv_proj")]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[offset:offset + shard_size]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
continue
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
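# --- Illustrative note (editor's sketch, not part of the vLLM diff) ---------
# The lm_head above rounds the vocabulary up to a multiple of 64 so the weight
# can be split evenly across tensor-parallel ranks:
#     ((32000 + 63) // 64) * 64 == 32000   (already a multiple of 64)
#     ((32001 + 63) // 64) * 64 == 32064   (63 padding rows appended)
# load_weights() then appends the matching number of extra rows to the
# embedding / lm_head checkpoints before sharding them.
# ----------------------------------------------------------------------------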

@@ -0,0 +1,377 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
import math
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from vllm.sequence import SequenceOutputs
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(start=1,
end=1 + 2 * num_remaining_heads,
step=2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
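# --- Illustrative note (editor's sketch, not part of the vLLM diff) ---------
# For a power-of-two head count the ALiBi slopes form a geometric series.
# With total_num_heads = 8, base = 2 ** -1, so
#     _get_alibi_slopes(8) -> [0.5, 0.25, 0.125, ..., 0.00390625]
# For non-power-of-two counts, extra interleaved slopes are appended via
# `extra_powers`, following the standard ALiBi construction.
# ----------------------------------------------------------------------------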
class BaiChuanMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class BaiChuanAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
hidden_size: int,
num_heads: int,
position_embedding: str,
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
)
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.postion_embedding = position_embedding
# pylint: disable=invalid-name
self.W_pack = ColumnParallelLinear(
hidden_size,
3 * hidden_size,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
# Create the alibi slopes and slice them.
if self.postion_embedding == "ALIBI":
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
scaling, alibi_slopes)
else:
self.scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
if self.postion_embedding == "ALIBI":
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
else:
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
class BaiChuanDecoderLayer(nn.Module):
def __init__(self, config: BaiChuanConfig, position_embedding: str):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = BaiChuanAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
position_embedding=position_embedding,
)
self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class BaiChuanModel(nn.Module):
def __init__(self, config: BaiChuanConfig, position_embedding: str):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([
BaiChuanDecoderLayer(config, position_embedding)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class BaiChuanBaseForCausalLM(nn.Module):
def __init__(self, config, position_embedding: str):
super().__init__()
self.config = config
self.model = BaiChuanModel(config, position_embedding)
self.lm_head = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = [
"embed_tokens.weight",
"lm_head.weight",
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "rotary_emb.inv_freq" in name:
continue
if "embed_tokens" in name or "lm_head" in name:
# Consider padding in the vocab size.
param = state_dict[name]
padded_vocab_size = param.shape[0] * tp_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
if "W_pack" in name:
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tp_world_size
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
(tp_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue
param = state_dict[name]
load_tensor_parallel_weights(
param,
loaded_weight,
name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank,
)
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b
def __init__(self, config):
super().__init__(config, "ALIBI")
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b
def __init__(self, config):
super().__init__(config, "ROPE")

@@ -0,0 +1,324 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
# Copyright 2023 The CacheFlow team.
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
import math
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import BloomConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
KVCache = Tuple[torch.Tensor, torch.Tensor]
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(start=1,
end=1 + 2 * num_remaining_heads,
step=2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
class BloomAttention(nn.Module):
def __init__(self, config: BloomConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.total_num_heads = config.n_head
self.head_dim = self.hidden_size // self.total_num_heads
assert self.head_dim * self.total_num_heads == self.hidden_size
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
self.query_key_value = ColumnParallelLinear(
self.hidden_size,
3 * self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False,
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False,
)
# Create the alibi slopes and slice them.
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
scaling, alibi_slopes)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
del position_ids # Unused.
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
output, _ = self.dense(attn_output)
return output
class BloomMLP(nn.Module):
def __init__(self, config: BloomConfig):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
4 * hidden_size,
gather_output=False,
perform_initialization=False)
self.act = get_act_fn("gelu")
self.dense_4h_to_h = RowParallelLinear(4 * hidden_size,
hidden_size,
input_is_parallel=True,
perform_initialization=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.dense_h_to_4h(x)
x = self.act(x)
x, _ = self.dense_4h_to_h(x)
return x
class BloomBlock(nn.Module):
def __init__(self, config: BloomConfig):
super().__init__()
hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config)
self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config)
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Layer norm post the self attention.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self attention.
attention_output = self.self_attention(
position_ids=position_ids,
hidden_states=layernorm_output,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
attention_output = attention_output + residual
layernorm_output = self.post_attention_layernorm(attention_output)
# Get residual
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# MLP.
output = self.mlp(layernorm_output) + residual
return output
class BloomModel(nn.Module):
def __init__(self, config: BloomConfig):
super().__init__()
self.embed_dim = config.hidden_size
# Embedding + LN Embedding
self.word_embeddings = VocabParallelEmbedding(
config.vocab_size, self.embed_dim, perform_initialization=False)
self.word_embeddings_layernorm = nn.LayerNorm(
self.embed_dim, eps=config.layer_norm_epsilon)
# Transformer blocks
self.h = nn.ModuleList(
[BloomBlock(config) for _ in range(config.num_hidden_layers)])
# Final Layer Norm
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(hidden_states)
for i in range(len(self.h)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class BloomForCausalLM(nn.Module):
def __init__(self, config: BloomConfig):
super().__init__()
self.config = config
self.transformer = BloomModel(config)
# TODO(zhuohan): create a new weight after implementing pipeline
# parallelism
self.lm_head_weight = self.transformer.word_embeddings.weight
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = [
"word_embeddings.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"
]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if name == "lm_head.weight":
# Since hidden_states are parallelized, we need to
# load lm_head.weight in parallel.
self._column_parallel_weights.append(name)
# If lm_head is provided, use it instead.
param = self.lm_head_weight
else:
if not name.startswith("transformer."):
name = "transformer." + name
param = state_dict[name]
if "query_key_value" in name:
# NOTE(woosuk): BLOOM's fused QKV has the shape of
# [num_heads * 3 * head_size, hidden_size], while the
# required shape is [3 * num_heads * head_size, hidden_size].
# Thus, we need weight conversion.
shard_size = param.shape[0]
start = shard_size * tp_rank
end = shard_size * (tp_rank + 1)
loaded_weight = loaded_weight[start:end]
num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // num_heads
if "query_key_value.weight" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size,
hidden_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif "query_key_value.bias" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)
else:
raise ValueError(f"Unexpected weight name: {name}")
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
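# --- Illustrative note (editor's sketch, not part of the vLLM diff) ---------
# The NOTE in load_weights(): BLOOM checkpoints store the fused QKV weight
# grouped per head, [num_heads * 3 * head_size, hidden_size], while the fused
# ColumnParallelLinear expects it grouped per Q/K/V,
# [3 * num_heads * head_size, hidden_size]. The same shuffle on a tiny
# made-up tensor (2 heads, head_size 3, hidden_size 4):
#     w = torch.arange(2 * 3 * 3 * 4, dtype=torch.float32).reshape(-1, 4)
#     w = w.view(-1, 3, 3, 4)    # [num_heads, qkv, head_size, hidden_size]
#     w = w.transpose(0, 1)      # [qkv, num_heads, head_size, hidden_size]
#     w = w.reshape(-1, 4)       # [3 * num_heads * head_size, hidden_size]
# ----------------------------------------------------------------------------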

@@ -0,0 +1,496 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Falcon model."""
import math
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import (PagedAttention,
PagedAttentionWithALiBi,
PagedAttentionWithRoPE)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
reduce_from_tensor_model_parallel_region)
from vllm.sequence import SequenceOutputs
from vllm.transformers_utils.configs import RWConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
FalconConfig = Union[HF_FalconConfig, RWConfig]
# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during
# training, this means that there's one additional quantization to bfloat16
# between the operations. In order not to degrade the quality of our HF-port,
# we keep these characteristics in the final model.
class FalconLinear(nn.Linear):
def forward(self, x: torch.Tensor) -> torch.Tensor:
hidden_states = x @ self.weight.T
if self.bias is None:
return hidden_states
return hidden_states + self.bias
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(1,
1 + 2 * num_remaining_heads,
2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
class FalconAttention(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.head_dim = self.hidden_size // self.total_num_heads
assert self.head_dim * self.total_num_heads == self.hidden_size
self.new_decoder_architecture = config.new_decoder_architecture
self.multi_query = config.multi_query
if self.new_decoder_architecture:
self.total_num_kv_heads = config.num_kv_heads
assert self.total_num_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.query_key_value = ColumnParallelLinear(
self.hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True,
)
elif self.multi_query:
self.total_num_kv_heads = 1
self.num_kv_heads = 1
self.query = ColumnParallelLinear(
self.hidden_size,
self.total_num_heads * self.head_dim,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True,
)
self.key_value = FalconLinear(self.hidden_size,
2 * self.head_dim,
bias=config.bias)
else:
self.total_num_kv_heads = self.total_num_heads
self.num_kv_heads = self.num_heads
self.query_key_value = ColumnParallelLinear(
self.hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True,
)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=config.bias,
input_is_parallel=True,
perform_initialization=False,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results)
self.use_rotary = config.rotary
self.use_alibi = config.alibi
assert not (self.use_rotary and self.use_alibi), (
"Rotary and alibi are mutually exclusive.")
if self.use_rotary:
# TODO(zhuohan): Pass in correct `max_position``
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.inv_norm_factor,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads)
elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
self.inv_norm_factor)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
self.attn = PagedAttentionWithALiBi(self.num_heads,
self.head_dim,
self.inv_norm_factor,
alibi_slopes,
num_kv_heads=self.num_kv_heads)
else:
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
if not self.new_decoder_architecture and self.multi_query:
q, bias = self.query(hidden_states)
if bias is not None:
q += bias
kv = self.key_value(hidden_states)
k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
else:
qkv, bias = self.query_key_value(hidden_states)
if bias is not None:
qkv += bias
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
k_cache, v_cache = kv_cache
if self.use_rotary:
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
else:
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output, bias = self.dense(attn_output)
return attn_output, bias
class FalconMLP(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
4 * hidden_size,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True)
self.act = nn.GELU()
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size,
hidden_size,
bias=config.bias,
input_is_parallel=True,
perform_initialization=False,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
x, bias = self.dense_h_to_4h(x)
if bias is not None:
x += bias
x = self.act(x)
x, bias = self.dense_4h_to_h(x)
return x, bias
class FalconDecoderLayer(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config)
self.mlp = FalconMLP(config)
self.config = config
if config.new_decoder_architecture:
# The layer norm before self-attention
self.ln_attn = LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
# The layer norm before the MLP
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
else:
self.input_layernorm = LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
if not config.parallel_attn:
self.post_attention_layernorm = LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
):
residual = hidden_states
if self.config.new_decoder_architecture:
attention_layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
attention_layernorm_out = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = self.self_attention(
positions=positions,
hidden_states=attention_layernorm_out,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
if self.reduce_row_parallel_results and attention_bias is not None:
attention_output += attention_bias
if not self.config.new_decoder_architecture:
if self.config.parallel_attn:
mlp_layernorm_out = attention_layernorm_out
else:
residual += attention_output
mlp_layernorm_out = self.post_attention_layernorm(residual)
# MLP.
mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
if self.reduce_row_parallel_results and mlp_bias is not None:
mlp_output += mlp_bias
if not self.reduce_row_parallel_results:
# When MLP and Attention layers are parallel, we can use
# only one all-reduce operator to reduce the results from
# both MLP and Attention layers.
mlp_output += attention_output
mlp_output = reduce_from_tensor_model_parallel_region(mlp_output)
if attention_bias is not None:
mlp_output += attention_bias
if mlp_bias is not None:
mlp_output += mlp_bias
output = mlp_output + residual
return output
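# --- Illustrative note (editor's sketch, not part of the vLLM diff) ---------
# With parallel attention/MLP (reduce_row_parallel_results == False), the two
# partial row-parallel outputs are summed first and reduced with one
# collective, relying on the linearity of all-reduce:
#     all_reduce(attn_partial) + all_reduce(mlp_partial)
#         == all_reduce(attn_partial + mlp_partial)
# The (replicated) biases are added only after the reduction so they are not
# summed once per tensor-parallel rank.
# ----------------------------------------------------------------------------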
class FalconModel(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.use_alibi = config.alibi
# Embedding + LN Embedding
self.word_embeddings = VocabParallelEmbedding(
config.vocab_size, self.embed_dim, perform_initialization=False)
# Transformer blocks
self.h = nn.ModuleList([
FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
for i in range(len(self.h)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class FalconForCausalLM(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
self.config = config
self.transformer = FalconModel(config)
self.lm_head = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer(
input_ids,
positions,
kv_caches,
input_metadata,
cache_events,
)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = [
"word_embeddings.weight", "lm_head.weight", "dense_h_to_4h.weight",
"dense_h_to_4h.bias"
]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tp_size = (get_tensor_model_parallel_world_size())
tp_rank = get_tensor_model_parallel_rank()
hidden_size = self.config.hidden_size
total_num_heads = self.config.num_attention_heads
num_heads = total_num_heads // tp_size
head_size = hidden_size // total_num_heads
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
if self.config.new_decoder_architecture:
total_num_kv_heads = self.config.num_kv_heads
num_kv_heads = total_num_kv_heads // tp_size
separated_q_kv = False
kv_head_start = tp_rank * num_kv_heads
kv_head_end = (tp_rank + 1) * num_kv_heads
elif self.config.multi_query:
total_num_kv_heads = 1
num_kv_heads = 1
separated_q_kv = True
kv_head_start = 0
kv_head_end = 1
else:
total_num_kv_heads = total_num_heads
num_kv_heads = total_num_kv_heads // tp_size
separated_q_kv = False
kv_head_start = tp_rank * num_kv_heads
kv_head_end = (tp_rank + 1) * num_kv_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "query_key_value" in name:
loaded_weight_size = loaded_weight.size()
loaded_weight = loaded_weight.view(
total_num_kv_heads, num_query_heads_per_kv_head + 2,
head_size, *loaded_weight_size[1:])
wq = loaded_weight[:, :-2].reshape(-1, *loaded_weight_size[1:])
wk = loaded_weight[:, [-2]].reshape(-1,
*loaded_weight_size[1:])
wv = loaded_weight[:, [-1]].reshape(-1,
*loaded_weight_size[1:])
wq = wq[head_size * head_start:head_size * head_end]
wk = wk[head_size * kv_head_start:head_size * kv_head_end]
wv = wv[head_size * kv_head_start:head_size * kv_head_end]
if separated_q_kv:
loaded_weight_q = wq
loaded_weight_kv = torch.cat([wk, wv], dim=0)
q_weight_name = name.replace("query_key_value", "query")
kv_weight_name = name.replace("query_key_value",
"key_value")
load_tensor_parallel_weights(state_dict[q_weight_name],
loaded_weight_q,
q_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank)
load_tensor_parallel_weights(state_dict[kv_weight_name],
loaded_weight_kv,
kv_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank)
continue
else:
loaded_weight = torch.cat([wq, wk, wv], dim=0)
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
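# The fused "query_key_value" weight above is viewed as
# [total_num_kv_heads, num_query_heads_per_kv_head + 2, head_size, ...]:
# each KV group holds its query rows followed by one key row and one value
# row, which is why the slices take [:, :-2], [:, [-2]] and [:, [-1]].
# (For example, with 128 query heads and 8 KV heads, each group holds 16
# query rows plus one key row and one value row.)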

View File

@@ -1,5 +1,6 @@
 # coding=utf-8
-# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
 # Copyright 2023 The vLLM team.
 # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
 # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
@@ -47,19 +48,25 @@ class GPT2Attention(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         total_num_heads = config.num_attention_heads
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         assert total_num_heads % tensor_model_parallel_world_size == 0
         self.num_heads = total_num_heads // tensor_model_parallel_world_size
         self.head_dim = self.hidden_size // total_num_heads
         self.scale = self.head_dim**-0.5
-        self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
-                                           bias=True, gather_output=False,
+        self.c_attn = ColumnParallelLinear(self.hidden_size,
+                                           3 * self.hidden_size,
+                                           bias=True,
+                                           gather_output=False,
                                            perform_initialization=False)
-        self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
-                                        bias=True, input_is_parallel=True,
+        self.c_proj = RowParallelLinear(self.hidden_size,
+                                        self.hidden_size,
+                                        bias=True,
+                                        input_is_parallel=True,
                                         perform_initialization=False)
-        self.attn = PagedAttention(self.num_heads, self.head_dim,
+        self.attn = PagedAttention(self.num_heads,
+                                   self.head_dim,
                                    scale=self.scale)
     def forward(
@@ -72,8 +79,8 @@ class GPT2Attention(nn.Module):
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         key_cache, value_cache = kv_cache
-        attn_output = self.attn(
-            q, k, v, key_cache, value_cache, input_metadata, cache_event)
+        attn_output = self.attn(q, k, v, key_cache, value_cache,
+                                input_metadata, cache_event)
         attn_output, _ = self.c_proj(attn_output)
         return attn_output
@@ -87,11 +94,15 @@ class GPT2MLP(nn.Module):
     ):
         super().__init__()
         hidden_size = config.hidden_size
-        self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
-                                         bias=True, gather_output=False,
+        self.c_fc = ColumnParallelLinear(hidden_size,
+                                         intermediate_size,
+                                         bias=True,
+                                         gather_output=False,
                                          perform_initialization=False)
-        self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
-                                        bias=True, input_is_parallel=True,
+        self.c_proj = RowParallelLinear(intermediate_size,
+                                        hidden_size,
+                                        bias=True,
+                                        input_is_parallel=True,
                                         perform_initialization=False)
         self.act = get_act_fn(config.activation_function)
@@ -107,7 +118,8 @@ class GPT2Block(nn.Module):
     def __init__(self, config: GPT2Config):
         super().__init__()
         hidden_size = config.hidden_size
-        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
+                     hidden_size)
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.attn = GPT2Attention(config)
@@ -145,9 +157,9 @@ class GPT2Model(nn.Module):
     def __init__(self, config: GPT2Config):
         super().__init__()
         self.config = config
-        assert config.add_cross_attention == False
-        assert config.scale_attn_by_inverse_layer_idx == False
-        assert config.reorder_and_upcast_attn == False
+        assert not config.add_cross_attention
+        assert not config.scale_attn_by_inverse_layer_idx
+        assert not config.reorder_and_upcast_attn
         self.embed_dim = config.hidden_size
         # Optimization: While the vocab size of GPT-2 is 50257, we extend it
@@ -180,8 +192,8 @@ class GPT2Model(nn.Module):
             else:
                 cache_event = cache_events[i]
             layer = self.h[i]
-            hidden_states = layer(
-                hidden_states, kv_caches[i], input_metadata, cache_event)
+            hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
+                                  cache_event)
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
@@ -206,19 +218,21 @@ class GPT2LMHeadModel(nn.Module):
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> Dict[int, SequenceOutputs]:
-        hidden_states = self.transformer(
-            input_ids, positions, kv_caches, input_metadata, cache_events)
-        next_tokens = self.sampler(
-            self.lm_head_weight, hidden_states, input_metadata)
+        hidden_states = self.transformer(input_ids, positions, kv_caches,
+                                         input_metadata, cache_events)
+        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
+                                   input_metadata)
         return next_tokens
     _column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
     _row_parallel_weights = ["c_proj.weight"]
-    def load_weights(self, model_name_or_path: str,
+    def load_weights(self,
+                     model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
@@ -228,10 +242,12 @@ class GPT2LMHeadModel(nn.Module):
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
                 continue
-            if ".attn.bias" in name:
+            if ".attn.bias" in name or ".attn.masked_bias" in name:
                 # Skip attention mask.
                 # NOTE: "c_attn.bias" should not be skipped.
                 continue
-            name = "transformer." + name
+            if not name.startswith("transformer."):
+                name = "transformer." + name
             # The HF's GPT-2 implementation uses Conv1D instead of Linear.
@@ -246,16 +262,20 @@ class GPT2LMHeadModel(nn.Module):
             if name == "transformer.wte.weight":
                 # Consider padding in the vocab size.
-                padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
+                padded_vocab_size = (param.shape[0] *
+                                     tensor_model_parallel_world_size)
                 num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
+                extra_rows = torch.empty(num_extra_rows,
+                                         loaded_weight.shape[1])
                 extra_rows = extra_rows.to(loaded_weight)
                 loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
             # For the fused QKV linear layer, manually shard the weights.
             if "c_attn" in name:
-                # GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
-                # When tensor parallelism is used, we shard the weights along the head dimension.
+                # GPT-2's fused QKV has the shape of
+                # [3 * num_heads * head_size, hidden_size].
+                # When tensor parallelism is used, we shard the weights along
+                # the head dimension.
                 total_num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
                 head_size = hidden_size // total_num_heads
@@ -264,11 +284,13 @@ class GPT2LMHeadModel(nn.Module):
                 head_end = (tensor_model_parallel_rank + 1) * num_heads
                 if name.endswith(".weight"):
-                    loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
+                    loaded_weight = loaded_weight.view(3, total_num_heads,
+                                                       head_size, hidden_size)
                     loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                     loaded_weight = loaded_weight.reshape(-1, hidden_size)
                 elif name.endswith(".bias"):
-                    loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
+                    loaded_weight = loaded_weight.view(3, total_num_heads,
+                                                       head_size)
                     loaded_weight = loaded_weight[:, head_start:head_end, :]
                     loaded_weight = loaded_weight.reshape(-1)
                 else:

View File

@@ -0,0 +1,343 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
# Copyright 2023 The vLLM team.
# Copyright 2023 CTranslate2, and Michael Feil
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import GPTBigCodeConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTBigCodeAttention(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
self.tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert total_num_heads % self.tensor_model_parallel_world_size == 0
self.num_heads = (total_num_heads //
self.tensor_model_parallel_world_size)
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim**-0.5
self.multi_query = config.multi_query
if self.multi_query:
self.num_kv_heads = 1
self.kv_dim = self.head_dim
self.c_attn_q = ColumnParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_attn_kv = nn.Linear(self.hidden_size,
2 * self.kv_dim,
bias=True)
else:
self.num_kv_heads = self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.c_attn = ColumnParallelLinear(self.hidden_size,
self.hidden_size +
2 * self.kv_dim,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scale,
num_kv_heads=self.num_kv_heads)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
if self.multi_query:
q, _ = self.c_attn_q(hidden_states)
kv = self.c_attn_kv(hidden_states)
k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
else:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split([
self.hidden_size // self.tensor_model_parallel_world_size,
self.kv_dim, self.kv_dim
],
dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
attn_output, _ = self.c_proj(attn_output)
return attn_output
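# In the multi-query path above, the query projection (c_attn_q) is
# column-parallel and therefore sharded across tensor-parallel ranks, while
# the single key/value head (c_attn_kv) is a plain nn.Linear replicated on
# every rank; load_weights below replicates the KV weights accordingly.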
class GPTBigMLP(nn.Module):
def __init__(
self,
intermediate_size: int,
config: GPTBigCodeConfig,
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.act = get_act_fn(config.activation_function)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.c_proj(hidden_states)
return hidden_states
class GPTBigCodeBlock(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
super().__init__()
hidden_size = config.hidden_size
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# residual connection
hidden_states = attn_output + residual
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + feed_forward_hidden_states
return hidden_states
class GPTBigCodeModel(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
super().__init__()
self.config = config
assert not config.add_cross_attention
self.embed_dim = config.hidden_size
# Optimization: While the vocab size of GPT-2 is 50257, we extend it
# to 50304 in order to make it divisible by 64.
# This improves performance since GPUs are faster if the dimension
# is divisible by 64. In addition, it allows us to shard the embedding
# layer across 2, 4, 8, or more GPUs.
vocab_size = ((config.vocab_size + 63) // 64) * 64
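        # For example, GPT-2's vocab size of 50257 becomes
        # ((50257 + 63) // 64) * 64 = 50304.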
self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList(
[GPTBigCodeBlock(config) for _ in range(config.num_hidden_layers)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
for i in range(len(self.h)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
cache_event)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPTBigCodeForCausalLM(nn.Module):
def __init__(self, config: GPTBigCodeConfig):
super().__init__()
self.config = config
self.transformer = GPTBigCodeModel(config)
# TODO(zhuohan): create a new weight after implementing pipeline
# parallelism
self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue
if ".attn.bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if not name.startswith("transformer."):
name = "transformer." + name
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
total_num_kv_heads = (1 if self.config.multi_query else
total_num_heads)
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
total_kv_size = head_size * total_num_kv_heads
num_heads = total_num_heads // tensor_model_parallel_world_size
head_start = tensor_model_parallel_rank * num_heads
head_end = (tensor_model_parallel_rank + 1) * num_heads
wq, wk, wv = torch.split(
loaded_weight, [hidden_size, total_kv_size, total_kv_size],
dim=0)
wq = wq[head_size * head_start:head_size * head_end]
if not self.config.multi_query:
# Split the heads when using normal multi-head attention
wk = wk[head_size * head_start:head_size * head_end]
wv = wv[head_size * head_start:head_size * head_end]
loaded_weight = torch.cat([wq, wk, wv], dim=0)
else:
# For multi-query attention, we split the query
# but replicate the key and value.
loaded_weight_q = wq
loaded_weight_kv = torch.cat([wk, wv], dim=0)
q_weight_name = name.replace("c_attn", "c_attn_q")
kv_weight_name = name.replace("c_attn", "c_attn_kv")
load_tensor_parallel_weights(state_dict[q_weight_name],
loaded_weight_q,
q_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
load_tensor_parallel_weights(state_dict[kv_weight_name],
loaded_weight_kv,
kv_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
continue
param = state_dict[name]
if name == "transformer.wte.weight":
# Consider padding in the vocab size.
padded_vocab_size = param.shape[
0] * tensor_model_parallel_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)

View File

@@ -0,0 +1,251 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import GPTJConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTJAttention(nn.Module):
def __init__(self, config: GPTJConfig):
super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
self.qkv_proj = ColumnParallelLinear(config.hidden_size,
3 * config.hidden_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.out_proj = RowParallelLinear(config.hidden_size,
config.hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
scaling = self.head_size**-0.5
assert getattr(config, "rotary", True)
assert config.rotary_dim % 2 == 0
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
scaling, config.rotary_dim)
self.warmup = False
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
attn_output, _ = self.out_proj(attn_output)
return attn_output
class GPTJMLP(nn.Module):
def __init__(self, intermediate_size: int, config: GPTJConfig):
super().__init__()
hidden_size = config.n_embd
self.fc_in = ColumnParallelLinear(hidden_size,
intermediate_size,
gather_output=False,
perform_initialization=False)
self.fc_out = RowParallelLinear(intermediate_size,
hidden_size,
input_is_parallel=True,
perform_initialization=False)
self.act = get_act_fn(config.activation_function)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc_in(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.fc_out(hidden_states)
return hidden_states
class GPTJBlock(nn.Module):
def __init__(self, config: GPTJConfig):
super().__init__()
if config.n_inner is None:
inner_dim = 4 * config.n_embd
else:
inner_dim = config.n_inner
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config)
self.mlp = GPTJMLP(inner_dim, config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
mlp_output = self.mlp(hidden_states)
hidden_states = attn_output + mlp_output + residual
return hidden_states
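# GPT-J computes attention and the MLP in parallel from the same
# layer-normalized input (only ln_1 above), so the block output is
# attn_output + mlp_output + residual rather than two sequential residual
# additions.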
class GPTJModel(nn.Module):
def __init__(self, config: GPTJConfig):
super().__init__()
self.config = config
self.embed_dim = config.n_embd
self.wte = VocabParallelEmbedding(config.vocab_size,
self.embed_dim,
perform_initialization=False)
self.h = nn.ModuleList(
[GPTJBlock(config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.h)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPTJForCausalLM(nn.Module):
def __init__(self, config: GPTJConfig):
super().__init__()
self.config = config
assert not config.tie_word_embeddings
self.transformer = GPTJModel(config)
self.lm_head = ColumnParallelLinear(config.n_embd,
config.vocab_size,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata, self.lm_head.bias)
return next_tokens
_column_parallel_weights = [
"wte.weight", "fc_in.weight", "fc_in.bias", "lm_head.weight",
"lm_head.bias"
]
_row_parallel_weights = ["out_proj.weight", "fc_out.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "attn.bias" in name or "attn.masked_bias" in name:
continue
is_attention_weight = False
for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[1]
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
(tp_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
continue
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)

View File

@@ -1,5 +1,6 @@
 # coding=utf-8
-# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
 #
@@ -48,15 +49,19 @@ class GPTNeoXAttention(nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.total_num_heads
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         assert self.total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
-        self.query_key_value = ColumnParallelLinear(config.hidden_size,
-                                                    3 * config.hidden_size,
-                                                    gather_output=False,
-                                                    perform_initialization=False)
-        self.dense = RowParallelLinear(config.hidden_size, config.hidden_size,
-                                       input_is_parallel=True,
-                                       perform_initialization=False)
+        self.num_heads = (self.total_num_heads //
+                          tensor_model_parallel_world_size)
+        self.query_key_value = ColumnParallelLinear(
+            config.hidden_size,
+            3 * config.hidden_size,
+            gather_output=False,
+            perform_initialization=False)
+        self.dense = RowParallelLinear(config.hidden_size,
+                                       config.hidden_size,
+                                       input_is_parallel=True,
+                                       perform_initialization=False)
@@ -75,11 +80,10 @@ class GPTNeoXAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(
-            position_ids, q, k, v, k_cache, v_cache, input_metadata, cache_event)
+        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
+                                input_metadata, cache_event)
         output, _ = self.dense(attn_output)
         return output
@@ -92,7 +96,8 @@ class GPTNeoXMLP(nn.Module):
                                                   config.intermediate_size,
                                                   gather_output=False,
                                                   perform_initialization=False)
-        self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size,
+        self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
+                                               config.hidden_size,
                                                input_is_parallel=True,
                                                perform_initialization=False)
         self.act = get_act_fn(config.hidden_act)
@@ -109,8 +114,10 @@ class GPTNeoXLayer(nn.Module):
     def __init__(self, config: GPTNeoXConfig):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
-        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_eps)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+                                                     eps=config.layer_norm_eps)
         self.attention = GPTNeoXAttention(config)
         self.mlp = GPTNeoXMLP(config)
@@ -154,10 +161,13 @@ class GPTNeoXModel(nn.Module):
         super().__init__()
         self.config = config
-        self.embed_in = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
+        self.embed_in = VocabParallelEmbedding(config.vocab_size,
+                                               config.hidden_size,
                                                perform_initialization=False)
-        self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
-        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.layers = nn.ModuleList(
+            [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
+        self.final_layer_norm = nn.LayerNorm(config.hidden_size,
+                                             eps=config.layer_norm_eps)
     def forward(
         self,
@@ -191,8 +201,10 @@ class GPTNeoXForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.gpt_neox = GPTNeoXModel(config)
-        self.embed_out = ColumnParallelLinear(config.hidden_size, config.vocab_size,
-                                              bias=False, gather_output=False,
+        self.embed_out = ColumnParallelLinear(config.hidden_size,
+                                              config.vocab_size,
+                                              bias=False,
+                                              gather_output=False,
                                               perform_initialization=False)
         self.sampler = Sampler(config.vocab_size)
@@ -204,16 +216,20 @@ class GPTNeoXForCausalLM(nn.Module):
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> Dict[int, SequenceOutputs]:
-        hidden_states = self.gpt_neox(
-            input_ids, positions, kv_caches, input_metadata, cache_events)
-        next_tokens = self.sampler(
-            self.embed_out.weight, hidden_states, input_metadata)
+        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
+                                      input_metadata, cache_events)
+        next_tokens = self.sampler(self.embed_out.weight, hidden_states,
+                                   input_metadata)
         return next_tokens
-    _column_parallel_weights = ["embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
+    _column_parallel_weights = [
+        "embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight",
+        "dense_h_to_4h.bias"
+    ]
     _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
-    def load_weights(self, model_name_or_path: str,
+    def load_weights(self,
+                     model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
@@ -230,17 +246,19 @@ class GPTNeoXForCausalLM(nn.Module):
                 # required shape is [3 * num_heads * head_size, hidden_size].
                 # Thus, we need weight conversion.
                 shard_size = param.shape[0]
-                loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
-                                              :shard_size * (tensor_model_parallel_rank + 1)]
+                loaded_weight = loaded_weight[
+                    shard_size * tensor_model_parallel_rank:shard_size *
+                    (tensor_model_parallel_rank + 1)]
                 num_heads = self.config.num_attention_heads
                 hidden_size = self.config.hidden_size
                 head_size = hidden_size // num_heads
-                if 'query_key_value.weight' in name:
-                    loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size)
+                if "query_key_value.weight" in name:
+                    loaded_weight = loaded_weight.view(-1, 3, head_size,
+                                                       hidden_size)
                     loaded_weight = loaded_weight.transpose(0, 1)
                     loaded_weight = loaded_weight.reshape(-1, hidden_size)
-                elif 'query_key_value.bias' in name:
+                elif "query_key_value.bias" in name:
                     loaded_weight = loaded_weight.view(-1, 3, head_size)
                     loaded_weight = loaded_weight.transpose(0, 1)
                     loaded_weight = loaded_weight.reshape(-1)

View File

@@ -0,0 +1,299 @@
# -*- coding: utf-8 -*-
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.sequence import SequenceOutputs
KVCache = Tuple[torch.Tensor, torch.Tensor]
class InternLMMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
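# gate_up_proj fuses the gate and up projections into one column-parallel
# GEMM producing 2 * intermediate_size channels; SiluAndMul then splits the
# result in half and returns silu(gate) * up, so each MLP needs only a single
# fused projection plus one activation kernel before down_proj.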
class InternLMAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(
hidden_size,
3 * self.total_num_heads * self.head_dim,
bias=True,
gather_output=False,
perform_initialization=False,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False,
)
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
class InternLMDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = InternLMAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
)
self.mlp = InternLMMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class InternLMModel(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding(
vocab_size, config.hidden_size, perform_initialization=False)
self.layers = nn.ModuleList([
InternLMDecoderLayer(config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class InternLMForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.model = InternLMModel(config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = [
"embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
"gate_proj.weight", "up_proj.weight"
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "rotary_emb.inv_freq" in name:
continue
if "embed_tokens" in name or "lm_head" in name:
param = state_dict[name]
# Consider padding in the vocab size.
padded_vocab_size = (param.shape[0] *
tensor_model_parallel_world_size)
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
is_attention_weight = False
for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
continue
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)

View File

@ -1,5 +1,6 @@
# coding=utf-8 # coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py # Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# #
@ -30,7 +31,6 @@ import torch
from torch import nn from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.sequence import SequenceOutputs
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -56,15 +56,19 @@ class LlamaMLP(nn.Module):
hidden_act: str, hidden_act: str,
): ):
super().__init__() super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size, self.gate_up_proj = ColumnParallelLinear(hidden_size,
bias=False, gather_output=False, 2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False) perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size, hidden_size, self.down_proj = RowParallelLinear(intermediate_size,
bias=False, input_is_parallel=True, hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False) perform_initialization=False)
if hidden_act != 'silu': if hidden_act != "silu":
raise ValueError(f'Unsupported activation: {hidden_act}. ' raise ValueError(f"Unsupported activation: {hidden_act}. "
'Only silu is supported for now.') "Only silu is supported for now.")
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
def forward(self, x): def forward(self, x):
@ -80,19 +84,26 @@ class LlamaAttention(nn.Module):
self, self,
hidden_size: int, hidden_size: int,
num_heads: int, num_heads: int,
num_kv_heads: int,
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0 assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.head_dim = hidden_size // self.total_num_heads self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear( self.qkv_proj = ColumnParallelLinear(
hidden_size, hidden_size,
3 * self.total_num_heads * self.head_dim, (self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=False, bias=False,
gather_output=False, gather_output=False,
perform_initialization=False, perform_initialization=False,
@ -104,8 +115,11 @@ class LlamaAttention(nn.Module):
input_is_parallel=True, input_is_parallel=True,
perform_initialization=False, perform_initialization=False,
) )
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim, self.attn = PagedAttentionWithRoPE(self.num_heads,
self.scaling, rotary_dim=self.head_dim) self.head_dim,
self.scaling,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads)
def forward( def forward(
self, self,
@ -116,10 +130,10 @@ class LlamaAttention(nn.Module):
cache_event: Optional[torch.cuda.Event], cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
attn_output = self.attn( attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event) input_metadata, cache_event)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
@ -132,14 +146,17 @@ class LlamaDecoderLayer(nn.Module):
self.self_attn = LlamaAttention( self.self_attn = LlamaAttention(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
) )
self.mlp = LlamaMLP( self.mlp = LlamaMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
) )
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = RMSNorm(config.hidden_size,
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward( def forward(
self, self,
@ -177,9 +194,12 @@ class LlamaModel(nn.Module):
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size, vocab_size = ((config.vocab_size + 63) // 64) * 64
perform_initialization=False) self.embed_tokens = VocabParallelEmbedding(
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) vocab_size, config.hidden_size, perform_initialization=False)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward( def forward(
@ -209,12 +229,14 @@ class LlamaModel(nn.Module):
class LlamaForCausalLM(nn.Module): class LlamaForCausalLM(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.config = config self.config = config
self.model = LlamaModel(config) self.model = LlamaModel(config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(config.hidden_size, self.lm_head = ColumnParallelLinear(config.hidden_size,
                                            vocab_size,
                                            bias=False,
                                            gather_output=False,
                                            perform_initialization=False)
@ -228,21 +250,35 @@ class LlamaForCausalLM(nn.Module):
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = [
        "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
        "gate_proj.weight", "up_proj.weight"
    ]
    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tp_size = get_tensor_model_parallel_world_size()
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        q_proj_shard_size = (self.config.hidden_size // tp_size)
        kv_proj_shard_size = (self.config.hidden_size //
                              self.config.num_attention_heads *
                              self.config.num_key_value_heads // tp_size)
        attention_weight_specs = [
            # (weight_name, shard_size, offset)
            ("q_proj", q_proj_shard_size, 0),
            ("k_proj", kv_proj_shard_size, q_proj_shard_size),
            ("v_proj", kv_proj_shard_size,
             q_proj_shard_size + kv_proj_shard_size),
        ]
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
@ -250,18 +286,28 @@ class LlamaForCausalLM(nn.Module):
            if "rotary_emb.inv_freq" in name:
                continue

            if "embed_tokens" in name or "lm_head" in name:
                param = state_dict[name]
                # Consider padding in the vocab size.
                padded_vocab_size = (param.shape[0] * tp_size)
                num_extra_rows = padded_vocab_size - self.config.vocab_size
                extra_rows = torch.empty(num_extra_rows,
                                         loaded_weight.shape[1])
                extra_rows = extra_rows.to(loaded_weight)
                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

            is_attention_weight = False
            for weight_name, shard_size, offset in attention_weight_specs:
                if weight_name not in name:
                    continue
                param = state_dict[name.replace(weight_name, "qkv_proj")]

                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[offset:offset + shard_size]
                assert param_slice.shape == loaded_weight.shape

                param_slice.copy_(loaded_weight)
                is_attention_weight = True
                break
@ -275,10 +321,10 @@ class LlamaForCausalLM(nn.Module):
                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                shard_size = param.shape[0] // 2
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_gate_up_weight = True
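As an aside, a minimal sketch of the shard-size arithmetic that the new LLaMA load_weights relies on for grouped-query attention, using hypothetical sizes (the numbers below are illustrative, not taken from any particular checkpoint):

# Illustrative only: fused QKV shard sizes under tensor parallelism with GQA.
hidden_size = 4096           # hypothetical model width
num_attention_heads = 32
num_key_value_heads = 8      # GQA: fewer KV heads than query heads
tp_size = 2                  # tensor-parallel world size

head_dim = hidden_size // num_attention_heads                     # 128
q_proj_shard_size = hidden_size // tp_size                        # 2048
kv_proj_shard_size = head_dim * num_key_value_heads // tp_size    # 512

# Layout of one rank's fused qkv_proj.weight along dim 0:
# [0, q) -> Q shard, [q, q + kv) -> K shard, [q + kv, q + 2 * kv) -> V shard
offsets = {
    "q_proj": 0,
    "k_proj": q_proj_shard_size,
    "v_proj": q_proj_shard_size + kv_proj_shard_size,
}
print(q_proj_shard_size, kv_proj_shard_size, offsets)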

View File

@ -0,0 +1,280 @@
# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
from vllm.transformers_utils.configs.mpt import MPTConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
def _get_alibi_slopes(
total_num_heads: int,
alibi_bias_max: int,
) -> torch.Tensor:
next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
m = m.mul(alibi_bias_max / next_power_of_2)
slopes = 1.0 / torch.pow(2, m)
if next_power_of_2 != total_num_heads:
slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
return slopes
class MPTAttention(nn.Module):
def __init__(self, config: MPTConfig):
super().__init__()
self.d_model = config.d_model
self.total_num_heads = config.n_heads
self.clip_qkv = config.attn_config["clip_qkv"]
self.qk_ln = config.attn_config["qk_ln"]
self.alibi_bias_max = config.attn_config["alibi_bias_max"]
assert not config.attn_config["prefix_lm"]
assert config.attn_config["alibi"]
self.qkv_proj = ColumnParallelLinear(
self.d_model,
3 * self.d_model,
bias=not config.no_bias,
gather_output=False,
perform_initialization=False,
)
if self.qk_ln:
self.q_ln = nn.LayerNorm(self.d_model)
self.k_ln = nn.LayerNorm(self.d_model)
self.out_proj = RowParallelLinear(
self.d_model,
self.d_model,
bias=not config.no_bias,
input_is_parallel=True,
perform_initialization=False,
)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
# Create the alibi slopes and slice them.
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads,
self.alibi_bias_max)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
self.head_dim = self.d_model // self.total_num_heads
scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
scaling, alibi_slopes)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
del position_ids # unused.
qkv, _ = self.qkv_proj(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.qk_ln:
q = self.q_ln(q)
k = self.k_ln(k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
output, _ = self.out_proj(attn_output)
return output
class MPTMLP(nn.Module):
def __init__(self, config: MPTConfig):
super().__init__()
hidden_size = config.d_model
expansion_ratio = config.expansion_ratio
intermediate_size = expansion_ratio * hidden_size
self.up_proj = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=not config.no_bias,
gather_output=False,
perform_initialization=False)
self.act = get_act_fn("gelu")
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=not config.no_bias,
input_is_parallel=True,
perform_initialization=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.up_proj(x)
x = self.act(x)
x, _ = self.down_proj(x)
return x
class MPTBlock(nn.Module):
def __init__(self, config: MPTConfig):
super().__init__()
hidden_size = config.d_model
self.norm_1 = nn.LayerNorm(hidden_size)
self.attn = MPTAttention(config)
self.norm_2 = nn.LayerNorm(hidden_size)
self.ffn = MPTMLP(config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
x = self.norm_1(hidden_states)
x = self.attn(
position_ids=position_ids,
hidden_states=x,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = hidden_states + x
x = self.norm_2(hidden_states)
x = self.ffn(x)
hidden_states = hidden_states + x
return hidden_states
class MPTModel(nn.Module):
def __init__(self, config: MPTConfig):
super().__init__()
assert config.embedding_fraction == 1.0
assert config.norm_type == "low_precision_layernorm"
self.wte = VocabParallelEmbedding(config.vocab_size,
config.d_model,
perform_initialization=False)
self.blocks = nn.ModuleList(
[MPTBlock(config) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model)
if config.no_bias:
for module in self.modules():
if hasattr(module, "bias"):
if isinstance(module.bias, nn.Parameter):
# Remove the bias term in Linear and LayerNorm.
module.register_parameter("bias", None)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.blocks)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
block = self.blocks[i]
hidden_states = block(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm_f(hidden_states)
return hidden_states
class MPTForCausalLM(nn.Module):
def __init__(self, config: MPTConfig):
super().__init__()
self.config = config
assert config.tie_word_embeddings
self.transformer = MPTModel(config)
# TODO(zhuohan): create a new weight after implementing pipeline
# parallelism
self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["wte.weight", "up_proj.weight", "up_proj.bias"]
_row_parallel_weights = ["out_proj.weight", "down_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "Wqkv" in name:
# NOTE(woosuk): MPT's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor model parallelism is used, we need to shard
# the weight along the hidden dimension.
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tp_world_size
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
if name.endswith(".weight"):
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:
raise ValueError(f"Unexpected parameter name {name}")
name = name.replace("Wqkv", "qkv_proj")
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
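For illustration only, a standalone sketch of the head-wise Wqkv slicing described in the NOTE above, with made-up sizes rather than a real MPT checkpoint:

import torch

# Illustrative only: slicing a fused [3 * n_heads * head_size, hidden] QKV
# weight for one tensor-parallel rank (hypothetical sizes).
total_num_heads, head_size, hidden_size = 8, 16, 128
tp_world_size, tp_rank = 2, 1

wqkv = torch.randn(3 * total_num_heads * head_size, hidden_size)

num_heads = total_num_heads // tp_world_size        # heads kept on this rank
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads

# View as (q/k/v, head, head_dim, hidden), keep this rank's heads, flatten back.
shard = wqkv.view(3, total_num_heads, head_size, hidden_size)
shard = shard[:, head_start:head_end, :, :].reshape(-1, hidden_size)
assert shard.shape == (3 * num_heads * head_size, hidden_size)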

View File

@ -1,7 +1,9 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
# Copyright 2023 The vLLM team.
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -43,8 +45,9 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
class OPTLearnedPositionalEmbedding(nn.Embedding):

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the
        # embedding ids by 2 and adjust num_embeddings appropriately. Other
        # models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)
@ -62,20 +65,26 @@ class OPTAttention(nn.Module):
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        total_num_heads = num_heads
        assert num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = ColumnParallelLinear(embed_dim,
                                             3 * embed_dim,
                                             bias=bias,
                                             gather_output=False,
                                             perform_initialization=False)
        self.out_proj = RowParallelLinear(embed_dim,
                                          embed_dim,
                                          bias=bias,
                                          input_is_parallel=True,
                                          perform_initialization=False)
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scale=self.scaling)

    def forward(
@ -88,8 +97,8 @@ class OPTAttention(nn.Module):
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata, cache_event)
        output, _ = self.out_proj(attn_output)
        return output
@ -109,17 +118,21 @@ class OPTDecoderLayer(nn.Module):
        self.activation_fn = get_act_fn(config.activation_function)

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)
        self.fc1 = ColumnParallelLinear(self.embed_dim,
                                        config.ffn_dim,
                                        bias=config.enable_bias,
                                        gather_output=False,
                                        perform_initialization=False)
        self.fc2 = RowParallelLinear(config.ffn_dim,
                                     self.embed_dim,
                                     bias=config.enable_bias,
                                     input_is_parallel=True,
                                     perform_initialization=False)
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
@ -133,8 +146,7 @@ class OPTDecoderLayer(nn.Module):
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       input_metadata=input_metadata,
                                       cache_event=cache_event)
@ -167,7 +179,8 @@ class OPTDecoder(nn.Module):
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
            perform_initialization=False)
        # Positional embeddings are replicated (not sharded).
@ -176,26 +189,32 @@ class OPTDecoder(nn.Module):
        # Project out & in will be replicated if they exist.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(config.hidden_size,
                                         config.word_embed_proj_dim,
                                         bias=False)
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = nn.Linear(config.word_embed_proj_dim,
                                        config.hidden_size,
                                        bias=False)
        else:
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList(
            [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
@ -217,8 +236,8 @@ class OPTDecoder(nn.Module):
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
                                  cache_event)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
@ -241,8 +260,8 @@ class OPTModel(nn.Module):
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        return self.decoder(input_ids, positions, kv_caches, input_metadata,
                            cache_events)


class OPTForCausalLM(nn.Module):
@ -264,16 +283,19 @@ class OPTForCausalLM(nn.Module):
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = [
        "embed_tokens.weight", "fc1.weight", "fc1.bias"
    ]
    _row_parallel_weights = ["out_proj.weight", "fc2.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
@ -288,16 +310,17 @@ class OPTForCausalLM(nn.Module):
                name = "model." + name
            is_attention_weight = False
            for stride_id, att_weight_name in enumerate(
                ["q_proj", "k_proj", "v_proj"]):
                if att_weight_name not in name:
                    continue
                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                shard_size = param.shape[0] // 3
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_attention_weight = True

View File

@ -0,0 +1,316 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator,
load_tensor_parallel_weights,
)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear,
)
from vllm.sequence import SequenceOutputs
from vllm.transformers_utils.configs.qwen import QWenConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
class QWenMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str = "silu",
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(
hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.c_proj(x)
return x
class QWenAttention(nn.Module):
def __init__(self, hidden_size: int, num_heads: int,
max_position_embeddings: int):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
)
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
# pylint: disable=invalid-name
self.c_attn = ColumnParallelLinear(
hidden_size,
3 * hidden_size,
bias=True,
gather_output=False,
perform_initialization=False,
)
self.c_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
self.scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.c_proj(attn_output)
return output
class QWenBlock(nn.Module):
def __init__(self, config: QWenConfig):
super().__init__()
self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = QWenAttention(config.n_embd, config.num_attention_heads,
config.max_position_embeddings)
self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.mlp = QWenMLP(config.n_embd, config.ffn_hidden_size // 2)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
hidden_states = self.attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class QWenModel(nn.Module):
def __init__(self, config: QWenConfig):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.wte = VocabParallelEmbedding(vocab_size,
config.n_embd,
perform_initialization=False)
self.h = nn.ModuleList(
[QWenBlock(config) for _ in range(config.num_hidden_layers)])
self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.h)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class QWenLMHeadModel(nn.Module):
def __init__(self, config: QWenConfig):
super().__init__()
self.config = config
self.transformer = QWenModel(config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(
config.n_embd,
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = ["wte.weight", "lm_head.weight"]
_row_parallel_weights = ["c_proj.weight"]
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False,
):
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "rotary_emb.inv_freq" in name:
continue
if "wte" in name or "lm_head" in name:
# Consider padding in the vocab size.
param = state_dict[name]
padded_vocab_size = param.shape[0] * tp_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
if "c_attn" in name:
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tp_world_size
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
if "weight" in name:
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif "bias" in name:
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["w2", "w1"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
(tp_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue
param = state_dict[name]
load_tensor_parallel_weights(
param,
loaded_weight,
name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank,
)
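As a side note, a small sketch of the pad-to-a-multiple-of-64 vocab handling used above, with hypothetical sizes (not the real QWen vocabulary):

import torch

# Illustrative only: pad a vocab-sized weight up to a multiple of 64 so it
# splits evenly across tensor-parallel ranks.
vocab_size, hidden = 151851, 32                      # hypothetical sizes
padded_vocab_size = ((vocab_size + 63) // 64) * 64   # 151872

weight = torch.randn(vocab_size, hidden)             # checkpoint weight
extra_rows = torch.empty(padded_vocab_size - vocab_size, hidden).to(weight)
padded = torch.cat([weight, extra_rows], dim=0)
assert padded.shape == (padded_vocab_size, hidden) and padded.shape[0] % 64 == 0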

View File

@ -1,9 +1,6 @@
import vllm.model_executor.parallel_utils.parallel_state
import vllm.model_executor.parallel_utils.tensor_parallel
# Alias parallel_state as mpu, its legacy name
mpu = parallel_state
__all__ = [
    "parallel_state",
    "tensor_parallel",

View File

@ -44,7 +44,6 @@ _PIPELINE_GLOBAL_RANKS = None
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None

_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None

def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
@ -196,20 +195,6 @@ def initialize_model_parallel(
        if rank in ranks:
            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

def initialize_all_reduce_launcher(
    max_num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    disable_graph: bool = False,
) -> None:
    global _ALL_REDUCE_LAUNCHER
    _ALL_REDUCE_LAUNCHER = GraphAllReduce(
        max_num_tokens=max_num_tokens,
        hidden_size=hidden_size,
        dtype=dtype,
        disable_graph=disable_graph,
    )

def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized."""
    if _TENSOR_MODEL_PARALLEL_GROUP is None or \
@ -458,6 +443,7 @@ def get_pipeline_model_parallel_last_rank():
    last_rank_local = get_pipeline_model_parallel_world_size() - 1
    return _PIPELINE_GLOBAL_RANKS[last_rank_local]


def get_pipeline_model_parallel_next_rank():
    """Return the global rank that follows the caller in the pipeline"""
    assert _PIPELINE_GLOBAL_RANKS is not None, \
@ -485,10 +471,6 @@ def get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    return torch.distributed.get_rank(group=get_data_parallel_group())

def get_all_reduce_launcher() -> 'GraphAllReduce':
    assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
    return _ALL_REDUCE_LAUNCHER

def destroy_model_parallel():
    """Set the groups to none."""
    global _MODEL_PARALLEL_GROUP
@ -515,56 +497,3 @@ def destroy_model_parallel():
    _MPU_TENSOR_MODEL_PARALLEL_RANK = None
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
class GraphAllReduce:
def __init__(
self,
max_num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
disable_graph: bool = False,
) -> None:
self.max_num_tokens = max_num_tokens
self.hidden_size = hidden_size
self.disable_graph = disable_graph
tp_world_size = get_tensor_model_parallel_world_size()
if tp_world_size == 1:
return
self.group = get_tensor_model_parallel_group()
self.buffer = torch.empty(
size=(max_num_tokens, hidden_size),
dtype=dtype,
device='cuda',
)
# Build graphs for different number of tokens.
if not self.disable_graph:
self.graphs = {}
for num_tokens in range(8, max_num_tokens + 1, 8):
self.graphs[num_tokens] = self._build_graph(num_tokens)
def _build_graph(self, num_tokens: int) -> torch.cuda.CUDAGraph:
# Warm up.
torch.distributed.all_reduce(self.buffer[:num_tokens], group=self.group)
torch.cuda.synchronize()
# Build graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
torch.distributed.all_reduce(
self.buffer[:num_tokens], group=self.group)
torch.cuda.synchronize()
return graph
def launch(self, x: torch.Tensor) -> torch.Tensor:
# NOTE: x must be a slice of self.buffer.
num_tokens = x.shape[0]
if self.disable_graph:
torch.distributed.all_reduce(x, group=self.group)
else:
self.graphs[num_tokens].replay()
return x

View File

@ -12,6 +12,7 @@ from .mappings import (
    copy_to_tensor_model_parallel_region,
    gather_from_tensor_model_parallel_region,
    gather_from_sequence_parallel_region,
    reduce_from_tensor_model_parallel_region,
    scatter_to_tensor_model_parallel_region,
    scatter_to_sequence_parallel_region,
)
@ -38,7 +39,7 @@ __all__ = [
    "copy_to_tensor_model_parallel_region",
    "gather_from_tensor_model_parallel_region",
    "gather_from_sequence_parallel_region",
    "reduce_from_tensor_model_parallel_region",
    "scatter_to_tensor_model_parallel_region",
    "scatter_to_sequence_parallel_region",
    # random.py

View File

@ -14,7 +14,6 @@ from torch.nn.parameter import Parameter
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_all_reduce_launcher,
)
from .mappings import (
    copy_to_tensor_model_parallel_region,
@ -248,8 +247,8 @@ class ColumnParallelLinear(torch.nn.Module):
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        self.world_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, self.world_size)
        self.skip_bias_add = skip_bias_add

        if params_dtype is None:
@ -350,6 +349,7 @@ class RowParallelLinear(torch.nn.Module):
        params_dtype:
        use_cpu_initialization:
        perform_initialization:
        reduce_results:
    """

    def __init__(self, input_size, output_size, *,
@ -360,6 +360,7 @@ class RowParallelLinear(torch.nn.Module):
                 params_dtype=None,
                 use_cpu_initialization=False,
                 perform_initialization=True,
                 reduce_results=True,
                 ):
        super(RowParallelLinear, self).__init__()
@ -367,14 +368,19 @@ class RowParallelLinear(torch.nn.Module):
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        # Divide the weight matrix along the last dimension.
        self.world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.world_size)
        self.skip_bias_add = skip_bias_add

        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
@ -427,17 +433,12 @@ class RowParallelLinear(torch.nn.Module):
            input_parallel = input_
        else:
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = F.linear(input_parallel, self.weight)
        if self.reduce_results and self.world_size > 1:
            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        else:
            output_ = output_parallel

        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
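For context, a minimal single-process sketch of why the reduce_results path is correct: a row-parallel linear produces per-rank partial products whose sum equals the full matmul, and that sum is exactly what the all-reduce computes. Two ranks are simulated locally here; this is not the vLLM code path itself.

import torch

# Illustrative only: two "ranks" simulated in one process.
torch.manual_seed(0)
x = torch.randn(4, 8)            # activations, feature dim = 8
w = torch.randn(6, 8)            # nn.Linear-style weight: [out, in]

# Split the input features and the weight's input dim across two ranks.
x0, x1 = x[:, :4], x[:, 4:]
w0, w1 = w[:, :4], w[:, 4:]

partial0 = x0 @ w0.t()           # what rank 0 computes
partial1 = x1 @ w1.t()           # what rank 1 computes
full = x @ w.t()

# The all-reduce in RowParallelLinear is just this sum across ranks.
assert torch.allclose(partial0 + partial1, full, atol=1e-5)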

View File

@ -39,14 +39,17 @@ def hf_model_weights_iterator(
    else:
        hf_folder = model_name_or_path

    hf_bin_files = [
        x for x in glob.glob(os.path.join(hf_folder, "*.bin"))
        if not x.endswith("training_args.bin")
    ]

    if use_np_cache:
        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        with lock:
            if not os.path.exists(weight_names_file):
                weight_names = []
@ -57,10 +60,10 @@ def hf_model_weights_iterator(
                    with open(param_path, "wb") as f:
                        np.save(f, param.cpu().detach().numpy())
                    weight_names.append(name)
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)

        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)

        for name in weight_names:
@ -73,6 +76,8 @@ def hf_model_weights_iterator(
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param
            del state
            torch.cuda.empty_cache()


def load_tensor_parallel_weights(
@ -86,19 +91,20 @@ def load_tensor_parallel_weights(
    for p in column_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[0]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            loaded_weight = loaded_weight[start_idx:end_idx]
            break
    for p in row_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[1]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            loaded_weight = loaded_weight[:, start_idx:end_idx]
            break
    assert param.shape == loaded_weight.shape, (
        f"{param_name} shape mismatch between model and checkpoint: "
        f"{param.shape} != {loaded_weight.shape}")
    param.data.copy_(loaded_weight)

View File

@ -55,6 +55,7 @@ class RequestOutput:
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
    """

    def __init__(
        self,
        request_id: str,
@ -75,8 +76,9 @@ class RequestOutput:
        n = seq_group.sampling_params.n
        seqs = seq_group.get_seqs()
        assert n <= len(seqs)
        sorted_seqs = sorted(seqs,
                             key=lambda seq: seq.get_cumulative_logprob(),
                             reverse=True)
        top_n_seqs = sorted_seqs[:n]

        # Create the outputs.

View File

@ -1,6 +1,8 @@
"""Sampling parameters for text generation."""
from typing import List, Optional, Union

_SAMPLING_EPS = 1e-5


class SamplingParams:
    """Sampling parameters for text generation.
@ -50,7 +52,7 @@ class SamplingParams:
        top_p: float = 1.0,
        top_k: int = -1,
        use_beam_search: bool = False,
        stop: Union[None, str, List[str]] = None,
        ignore_eos: bool = False,
        max_tokens: int = 16,
        logprobs: Optional[int] = None,
@ -63,15 +65,20 @@ class SamplingParams:
        self.top_p = top_p
        self.top_k = top_k
        self.use_beam_search = use_beam_search
        if stop is None:
            self.stop = []
        elif isinstance(stop, str):
            self.stop = [stop]
        else:
            self.stop = list(stop)
        self.ignore_eos = ignore_eos
        self.max_tokens = max_tokens
        self.logprobs = logprobs

        self._verify_args()
        if self.use_beam_search:
            self._verify_beam_search()
        elif self.temperature < _SAMPLING_EPS:
            # Zero temperature means greedy sampling.
            self._verify_greedy_sampling()
@ -102,13 +109,13 @@ class SamplingParams:
            raise ValueError(
                f"logprobs must be non-negative, got {self.logprobs}.")

    def _verify_beam_search(self) -> None:
        if self.best_of == 1:
            raise ValueError("best_of must be greater than 1 when using beam "
                             f"search. Got {self.best_of}.")
        if self.temperature > _SAMPLING_EPS:
            raise ValueError("temperature must be 0 when using beam search.")
        if self.top_p < 1.0 - _SAMPLING_EPS:
            raise ValueError("top_p must be 1 when using beam search.")
        if self.top_k != -1:
            raise ValueError("top_k must be -1 when using beam search.")
@ -117,7 +124,7 @@ class SamplingParams:
        if self.best_of > 1:
            raise ValueError("best_of must be 1 when using greedy sampling."
                             f"Got {self.best_of}.")
        if self.top_p < 1.0 - _SAMPLING_EPS:
            raise ValueError("top_p must be 1 when using greedy sampling.")
        if self.top_k != -1:
            raise ValueError("top_k must be -1 when using greedy sampling.")

View File

@ -1,3 +1,4 @@
"""Sequence and its related classes."""
import copy
import enum
from typing import Dict, List, Optional, Union
@ -7,12 +8,14 @@ from vllm.sampling_params import SamplingParams

class SequenceStatus(enum.Enum):
    """Status of a sequence."""
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
@ -20,6 +23,7 @@ class SequenceStatus(enum.Enum):
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        ]

    @staticmethod
@ -30,12 +34,25 @@ class SequenceStatus(enum.Enum):
            finish_reason = "length"
        elif status == SequenceStatus.FINISHED_ABORTED:
            finish_reason = "abort"
        elif status == SequenceStatus.FINISHED_IGNORED:
            finish_reason = "length"
        else:
            finish_reason = None
        return finish_reason


class SequenceData:
    """Data associated with a sequence.

    Args:
        prompt_token_ids: The token IDs of the prompt.

    Attributes:
        prompt_token_ids: The token IDs of the prompt.
        output_token_ids: The token IDs of the output.
        cumulative_logprob: The cumulative log probability of the output.
    """

    def __init__(
        self,
@ -71,6 +88,15 @@ class SequenceData:

class Sequence:
    """Stores the data, status, and block information of a sequence.

    Args:
        seq_id: The ID of the sequence.
        prompt: The prompt of the sequence.
        prompt_token_ids: The token IDs of the prompt.
        block_size: The block size of the sequence. Should be the same as the
            block size used by the block manager and cache engine.
    """

    def __init__(
        self,
@ -101,7 +127,8 @@ class Sequence:
        self.logical_token_blocks.append(block)

    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
        cursor = 0
        while cursor < len(token_ids):
            if not self.logical_token_blocks:
                self._append_logical_block()
@ -111,8 +138,9 @@ class Sequence:
            last_block = self.logical_token_blocks[-1]

            num_empty_slots = last_block.get_num_empty_slots()
            last_block.append_tokens(token_ids[cursor:cursor +
                                               num_empty_slots])
            cursor += num_empty_slots

    def append_token_id(
        self,
@ -145,19 +173,27 @@ class Sequence:
    def is_finished(self) -> bool:
        return SequenceStatus.is_finished(self.status)

    def fork(self, child_seq: "Sequence") -> None:
        child_seq.logical_token_blocks = copy.deepcopy(
            self.logical_token_blocks)
        child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
        child_seq.data = copy.deepcopy(self.data)
        return None

    def __repr__(self) -> str:
        return (f"Sequence(seq_id={self.seq_id}, "
                f"status={self.status.name}, "
                f"num_blocks={len(self.logical_token_blocks)})")


class SequenceGroup:
    """A group of sequences that are generated from the same prompt.

    Args:
        request_id: The ID of the request.
        seqs: The list of sequences.
        sampling_params: The sampling parameters used to generate the outputs.
        arrival_time: The arrival time of the request.
    """

    def __init__(
        self,
@ -187,7 +223,7 @@ class SequenceGroup:
        for seq in self.seqs:
            if seq.seq_id == seq_id:
                return seq
        raise ValueError(f"Sequence {seq_id} not found.")

    def is_finished(self) -> bool:
        return all(seq.is_finished() for seq in self.seqs)
@ -199,14 +235,25 @@ class SequenceGroup:

class SequenceGroupMetadata:
    """Metadata for a sequence group. Used to create `InputMetadata`.

    Args:
        request_id: The ID of the request.
        is_prompt: Whether the request is at prompt stage.
        seq_data: The sequence data. (Seq id -> sequence data)
        sampling_params: The sampling parameters used to generate the outputs.
        block_tables: The block tables. (Seq id -> list of physical block
            numbers)
    """

    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
@ -216,13 +263,23 @@ class SequenceGroupMetadata:

class SequenceOutputs:
    """The model output associated with a sequence.

    Args:
        seq_id: The ID of the sequence.
        parent_seq_id: The ID of the parent sequence (for forking in beam
            search).
        output_token: The output token ID.
        logprobs: The logprobs of the output token.
            (Token id -> logP(x_i+1 | x_0, ..., x_i))
    """

    def __init__(
        self,
        seq_id: int,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, float],
    ) -> None:
        self.seq_id = seq_id
        self.parent_seq_id = parent_seq_id
@ -230,15 +287,15 @@ class SequenceOutputs:
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f"SequenceOutputs(seq_id={self.seq_id}, "
                f"parent_seq_id={self.parent_seq_id}, "
                f"output_token={self.output_token}), "
                f"logprobs={self.logprobs}")

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SequenceOutputs):
            return NotImplemented
        return (self.seq_id == other.seq_id
                and self.parent_seq_id == other.parent_seq_id
                and self.output_token == other.output_token
                and self.logprobs == other.logprobs)
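As an illustration of the cursor-based rewrite of _append_tokens_to_blocks, a toy version using plain Python lists as blocks (not the actual LogicalTokenBlock type):

# Illustrative only: fill fixed-size blocks by advancing a cursor instead of
# repeatedly re-slicing the token list.
BLOCK_SIZE = 4

def append_tokens_to_blocks(blocks: list, token_ids: list) -> None:
    cursor = 0
    while cursor < len(token_ids):
        if not blocks or len(blocks[-1]) == BLOCK_SIZE:
            blocks.append([])                       # start a new block
        last_block = blocks[-1]
        num_empty_slots = BLOCK_SIZE - len(last_block)
        last_block.extend(token_ids[cursor:cursor + num_empty_slots])
        cursor += num_empty_slots

blocks = []
append_tokens_to_blocks(blocks, list(range(10)))
assert blocks == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]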

View File

View File

@ -0,0 +1,33 @@
from transformers import AutoConfig, PretrainedConfig
from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import
_CONFIG_REGISTRY = {
"mpt": MPTConfig,
"baichuan": BaiChuanConfig,
"aquila": AquilaConfig,
"qwen": QWenConfig,
"RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
}
def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
try:
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code)
except ValueError as e:
if (not trust_remote_code and
"requires you to execute the configuration file" in str(e)):
err_msg = (
"Failed to load the model config. If the model is a custom "
"model not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model)
return config
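A hedged usage sketch of the new config loader, assuming it is importable as vllm.transformers_utils.config and that the (hypothetical) model name below is resolvable locally or on the Hub:

from vllm.transformers_utils.config import get_config

# Custom architectures such as MPT resolve through _CONFIG_REGISTRY; standard
# architectures fall back to AutoConfig.from_pretrained.
config = get_config("mosaicml/mpt-7b", trust_remote_code=True)
print(config.model_type)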

View File

@ -0,0 +1,16 @@
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
from vllm.transformers_utils.configs.aquila import AquilaConfig
from vllm.transformers_utils.configs.qwen import QWenConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
__all__ = [
"MPTConfig",
"BaiChuanConfig",
"AquilaConfig",
"QWenConfig",
"RWConfig",
]

View File

@ -0,0 +1,63 @@
# coding=utf-8
# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Aquila model configuration"""
from transformers import PretrainedConfig
class AquilaConfig(PretrainedConfig):
    model_type = "aquila"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=100008,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.006,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
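A quick sanity-check sketch of the defaults above (illustrative only, not part of the commit):

# Minimal sketch: instantiating AquilaConfig with its defaults.
from vllm.transformers_utils.configs import AquilaConfig

cfg = AquilaConfig()
assert cfg.model_type == "aquila"
assert cfg.vocab_size == 100008
assert cfg.max_position_embeddings == 2048
# Head dimension implied by the defaults: 4096 / 32 = 128.
assert cfg.hidden_size // cfg.num_attention_heads == 128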

View File

@ -0,0 +1,62 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.configuration_utils import PretrainedConfig
class BaiChuanConfig(PretrainedConfig):
    model_type = "baichuan"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=64000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

View File

@ -0,0 +1,87 @@
# Adapted from
# https://huggingface.co/tiiuae/falcon-7b/blob/main/configuration_RW.py
# Copyright 2023 The vLLM team.
# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Falcon configuration"""
from transformers.configuration_utils import PretrainedConfig
class RWConfig(PretrainedConfig):
    model_type = "falcon"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "num_hidden_layers": "n_layer",
        "num_attention_heads": "n_head",
        "num_kv_heads": "n_head_kv",
    }

    def __init__(
        self,
        vocab_size=250880,
        hidden_size=64,
        n_layer=2,
        n_head=8,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=1,
        eos_token_id=2,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        multi_query=True,
        n_head_kv=None,
        alibi=False,
        bias=False,
        parallel_attn=False,
        new_decoder_architecture=False,
        **kwargs,
    ) -> None:
        self.vocab_size = vocab_size
        # Backward compatibility with n_embed kwarg
        n_embed = kwargs.pop("n_embed", None)
        self.hidden_size = hidden_size if n_embed is None else n_embed
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.multi_query = multi_query
        self.n_head_kv = 1 if n_head_kv is None else n_head_kv
        self.alibi = alibi
        self.bias = bias
        self.parallel_attn = parallel_attn
        self.new_decoder_architecture = new_decoder_architecture
        if self.hidden_size == 8192:
            # Hack for falcon-40b
            self.new_decoder_architecture = True
        super().__init__(bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id,
                         **kwargs)

    @property
    def head_dim(self):
        return self.hidden_size // self.n_head

    @property
    def rotary(self):
        return not self.alibi
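For reference, a small sketch of how the attribute_map aliases and derived properties behave; the values are loosely falcon-7b-like and purely illustrative:

# Minimal sketch: RWConfig exposes HF-style names via attribute_map.
from vllm.transformers_utils.configs import RWConfig

cfg = RWConfig(hidden_size=4544, n_layer=32, n_head=71)
assert cfg.num_hidden_layers == 32      # alias for n_layer
assert cfg.num_attention_heads == 71    # alias for n_head
assert cfg.num_kv_heads == 1            # multi-query default => one KV head
assert cfg.head_dim == 4544 // 71       # 64
assert cfg.rotary is True               # rotary embeddings unless alibi=True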

View File

@ -0,0 +1,74 @@
# Adapted from
# https://huggingface.co/mosaicml/mpt-7b/blob/main/configuration_mpt.py
from typing import Any, Dict, Optional, Union
from transformers import PretrainedConfig
_ATTN_CONFIG_DEFAULTS = {
    "attn_type": "multihead_attention",
    "attn_pdrop": 0.0,
    "attn_impl": "triton",
    "qk_ln": False,
    "clip_qkv": None,
    "softmax_scale": None,
    "prefix_lm": False,
    "attn_uses_sequence_id": False,
    "alibi": False,
    "alibi_bias_max": 8,
}


class MPTConfig(PretrainedConfig):
    model_type = "mpt"
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "n_heads",
        "num_hidden_layers": "n_layers",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        expansion_ratio: int = 4,
        max_seq_len: int = 2048,
        vocab_size: int = 50368,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        learned_pos_emb: bool = True,
        attn_config: Optional[Dict[str, Any]] = None,
        init_device: str = "cpu",
        logit_scale: Optional[Union[float, str]] = None,
        no_bias: bool = False,
        verbose: int = 0,
        embedding_fraction: float = 1.0,
        norm_type: str = "low_precision_layernorm",
        use_cache: bool = False,
        **kwargs,
    ) -> None:
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.expansion_ratio = expansion_ratio
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.learned_pos_emb = learned_pos_emb
        if attn_config is None:
            self.attn_config = _ATTN_CONFIG_DEFAULTS
        else:
            self.attn_config = attn_config
        self.init_device = init_device
        self.logit_scale = logit_scale
        self.no_bias = no_bias
        self.verbose = verbose
        self.embedding_fraction = embedding_fraction
        self.norm_type = norm_type
        self.use_cache = use_cache
        if "name" in kwargs:
            del kwargs["name"]
        if "loss_fn" in kwargs:
            del kwargs["loss_fn"]
        super().__init__(**kwargs)
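Two details worth noting, shown in a brief sketch (not part of the commit): attribute_map lets generic code read MPT's native field names through the standard HF attribute names, and a user-supplied attn_config replaces the default dict wholesale rather than being merged with it:

# Minimal sketch of MPTConfig behaviour, assuming the defaults above.
from vllm.transformers_utils.configs import MPTConfig

cfg = MPTConfig()
assert cfg.hidden_size == 2048            # alias for d_model
assert cfg.num_attention_heads == 16      # alias for n_heads
assert cfg.attn_config["alibi"] is False  # defaults used when attn_config is None

# Passing attn_config replaces the whole dict; unspecified keys are absent.
cfg2 = MPTConfig(attn_config={"alibi": True})
assert "attn_pdrop" not in cfg2.attn_config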

View File

@ -0,0 +1,71 @@
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
from transformers import PretrainedConfig
class QWenConfig(PretrainedConfig):
    model_type = "qwen"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "hidden_size": "n_embd",
        "num_attention_heads": "n_head",
        "max_position_embeddings": "n_positions",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size=151851,
        n_embd=4096,
        n_layer=32,
        n_head=32,
        n_inner=None,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        scale_attn_weights=True,
        use_cache=True,
        eos_token_id=151643,
        apply_residual_connection_post_layernorm=False,
        bf16=True,
        kv_channels=128,
        rotary_pct=1.0,
        rotary_emb_base=10000,
        use_dynamic_ntk=False,
        use_logn_attn=False,
        use_flash_attn=True,
        ffn_hidden_size=22016,
        no_bias=True,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.eos_token_id = eos_token_id
        super().__init__(eos_token_id=eos_token_id,
                         tie_word_embeddings=tie_word_embeddings,
                         **kwargs)
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm)
        self.bf16 = bf16
        self.kv_channels = kv_channels
        self.rotary_pct = rotary_pct
        self.rotary_emb_base = rotary_emb_base
        self.use_dynamic_ntk = use_dynamic_ntk
        self.use_logn_attn = use_logn_attn
        self.use_flash_attn = use_flash_attn
        self.ffn_hidden_size = ffn_hidden_size
        self.no_bias = no_bias
        self.tie_word_embeddings = tie_word_embeddings
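A similarly brief check of the GPT-2-style field names and their aliases (illustrative sketch only):

# Minimal sketch: QWenConfig uses GPT-2-style names, aliased via attribute_map.
from vllm.transformers_utils.configs import QWenConfig

cfg = QWenConfig()
assert cfg.hidden_size == 4096       # alias for n_embd
assert cfg.num_hidden_layers == 32   # alias for n_layer
assert cfg.ffn_hidden_size == 22016  # MLP width is stored explicitly
assert cfg.eos_token_id == 151643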

View File

@ -0,0 +1,118 @@
from typing import List, Tuple, Union

from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

from vllm.logger import init_logger

logger = init_logger(__name__)

# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def get_tokenizer(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Huggingface."""
    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    if "llama" in tokenizer_name.lower() and kwargs.get("use_fast", True):
        logger.info(
            "For some LLaMA-based models, initializing the fast tokenizer may "
            "take a long time. To eliminate the initialization time, consider "
            f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
            "tokenizer.")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)
    except TypeError as e:
        # The LLaMA tokenizer causes a protobuf error in some environments.
        err_msg = (
            "Failed to load the tokenizer. If you are using a LLaMA-based "
            f"model, use '{_FAST_LLAMA_TOKENIZER}' instead of the original "
            "tokenizer.")
        raise RuntimeError(err_msg) from e
    except ValueError as e:
        # If the error pertains to the tokenizer class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        if (not trust_remote_code and
            ("does not exist or is not currently imported." in str(e)
             or "requires you to execute the tokenizer file" in str(e))):
            err_msg = (
                "Failed to load the tokenizer. If the tokenizer is a custom "
                "tokenizer not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        else:
            raise e

    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        logger.warning(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead.")
    return tokenizer


def detokenize_incrementally(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    prev_output_tokens: List[str],
    new_token_id: int,
    skip_special_tokens: bool,
) -> Tuple[str, str]:
    """Detokenizes the new token in conjunction with the previous output tokens.

    NOTE: This function does not update prev_output_tokens.

    Returns:
        new_token: The new token as a string.
        output_text: The new output text as a string.
    """
    if skip_special_tokens and (new_token_id in tokenizer.all_special_ids):
        return None, prev_output_tokens
    new_token = tokenizer.convert_ids_to_tokens(
        new_token_id, skip_special_tokens=skip_special_tokens)
    output_tokens = prev_output_tokens + [new_token]

    # Convert the tokens to a string.
    # Optimization: If the tokenizer does not have `added_tokens_encoder`,
    # then we can directly use `convert_tokens_to_string`.
    if not getattr(tokenizer, "added_tokens_encoder", {}):
        output_text = tokenizer.convert_tokens_to_string(output_tokens)
        return new_token, output_text

    # Adapted from
    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
    # NOTE(woosuk): The following code is slow because it runs a for loop over
    # the output_tokens. In Python, running a for loop over a list can be slow
    # even when the loop body is very simple.
    sub_texts = []
    current_sub_text = []
    for token in output_tokens:
        if skip_special_tokens and token in tokenizer.all_special_tokens:
            continue
        if token in tokenizer.added_tokens_encoder:
            if current_sub_text:
                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
                sub_texts.append(sub_text)
                current_sub_text = []
            sub_texts.append(token)
        else:
            current_sub_text.append(token)
    if current_sub_text:
        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
        sub_texts.append(sub_text)
    output_text = " ".join(sub_texts)
    return new_token, output_text
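To make the incremental-decoding contract concrete, a rough usage sketch, assuming the module is importable as vllm.transformers_utils.tokenizer; "gpt2" is only an example model, and the function deliberately leaves updating prev_output_tokens to the caller:

# Rough usage sketch of detokenize_incrementally during generation.
from vllm.transformers_utils.tokenizer import get_tokenizer, detokenize_incrementally

tokenizer = get_tokenizer("gpt2")  # illustrative model name
prev_output_tokens = []
for new_token_id in tokenizer.encode("Hello, world!"):
    new_token, output_text = detokenize_incrementally(
        tokenizer, prev_output_tokens, new_token_id,
        skip_special_tokens=True)
    if new_token is not None:
        prev_output_tokens.append(new_token)  # caller updates the history
print(output_text)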

View File

@ -1,4 +1,5 @@
 import enum
+from platform import uname
 import uuid

 import psutil
@ -16,9 +17,9 @@ class Counter:
         self.counter = start

     def __next__(self) -> int:
-        id = self.counter
+        i = self.counter
         self.counter += 1
-        return id
+        return i

     def reset(self) -> None:
         self.counter = 0
@ -36,3 +37,8 @@ def get_cpu_memory() -> int:

 def random_uuid() -> str:
     return str(uuid.uuid4().hex)
+
+
+def in_wsl() -> bool:
+    # Reference: https://github.com/microsoft/WSL/issues/4071
+    return "microsoft" in " ".join(uname()).lower()
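The rename from id to i avoids shadowing the id builtin, and in_wsl() gives callers a cheap platform check; a small sketch of both, assuming they remain importable from vllm.utils:

# Small sketch of the utilities touched above.
from vllm.utils import Counter, in_wsl

counter = Counter(start=0)  # start=0 assumed to match the default
request_ids = [next(counter) for _ in range(3)]
assert request_ids == [0, 1, 2]

if in_wsl():
    print("Running under WSL: pinned host memory will be disabled.")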

View File

@ -5,6 +5,10 @@ import torch

 from vllm import cache_ops
 from vllm.config import CacheConfig, ModelConfig, ParallelConfig
+from vllm.logger import init_logger
+from vllm.utils import in_wsl
+
+logger = init_logger(__name__)

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@ -85,16 +89,22 @@ class CacheEngine:
         cpu_cache: List[KVCache] = []
         key_block_shape = self.get_key_block_shape()
         value_block_shape = self.get_value_block_shape()
+        pin_memory = not in_wsl()
+        if not pin_memory:
+            # Pinning memory in WSL is not supported.
+            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
+            logger.warning("Using 'pin_memory=False' as WSL is detected. "
+                           "This may slow down the performance.")
         for _ in range(self.num_layers):
             key_blocks = torch.empty(
                 size=(self.num_cpu_blocks, *key_block_shape),
                 dtype=self.dtype,
-                pin_memory=True,
+                pin_memory=pin_memory,
             )
             value_blocks = torch.empty(
                 size=(self.num_cpu_blocks, *value_block_shape),
                 dtype=self.dtype,
-                pin_memory=True,
+                pin_memory=pin_memory,
             )
             cpu_cache.append((key_blocks, value_blocks))
         return cpu_cache
@ -110,11 +120,10 @@ class CacheEngine:
             src_key_cache, src_value_cache = src[i]
             dst_key_cache, dst_value_cache = dst[i]
             # Copy the key blocks.
-            cache_ops.swap_blocks(
-                src_key_cache, dst_key_cache, src_to_dst)
+            cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
             # Copy the value blocks.
-            cache_ops.swap_blocks(
-                src_value_cache, dst_value_cache, src_to_dst)
+            cache_ops.swap_blocks(src_value_cache, dst_value_cache,
+                                  src_to_dst)
             event = self.events[i]
             event.record(stream=self.cache_stream)
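The key change is that the CPU KV-cache blocks fall back to pageable memory when WSL is detected, since CUDA pinned allocations are unsupported there; a stripped-down sketch of that allocation decision, with purely illustrative shapes rather than the real block shapes:

# Stripped-down sketch of the pin_memory fallback used above.
import torch
from vllm.utils import in_wsl

pin_memory = not in_wsl()  # pinned host memory is unsupported under WSL
num_cpu_blocks, block_shape = 4, (16, 128)  # illustrative sizes only
cpu_block = torch.empty(
    size=(num_cpu_blocks, *block_shape),
    dtype=torch.float16,
    pin_memory=pin_memory,  # pageable host memory when running under WSL
)
print(cpu_block.is_pinned())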

View File

@ -1,13 +1,15 @@
 """A GPU worker class."""
-from typing import Dict, List, Tuple
+import os
+from typing import Dict, List, Tuple, Optional

 import torch
+import torch.distributed

 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.model_executor import get_model, InputMetadata, set_random_seed
 from vllm.model_executor.parallel_utils.parallel_state import (
-    initialize_model_parallel, initialize_all_reduce_launcher)
+    initialize_model_parallel)
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SequenceData, SequenceGroupMetadata, SequenceOutputs
 from vllm.worker.cache_engine import CacheEngine
@ -27,8 +29,8 @@ class Worker:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
-        rank: int,
-        distributed_init_method: str,
+        rank: Optional[int] = None,
+        distributed_init_method: Optional[str] = None,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
@ -36,19 +38,6 @@ class Worker:
         self.rank = rank
         self.distributed_init_method = distributed_init_method

-        # Initialize the distributed environment.
-        _init_distributed_environment(parallel_config, rank,
-                                      distributed_init_method)
-
-        # Initialize the model.
-        set_random_seed(self.model_config.seed)
-        self.model = get_model(model_config)
-        initialize_all_reduce_launcher(
-            self.scheduler_config.max_num_batched_tokens,
-            self.model_config.get_hidden_size(),
-            self.model_config.dtype,
-        )
-
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
         self.cache_config = None
@ -57,6 +46,26 @@ class Worker:
         self.cache_events = None
         self.gpu_cache = None

+    def init_model(self):
+        # This env var set by Ray causes exceptions with graph building.
+        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+        # Env vars will be set by Ray.
+        self.rank = self.rank if self.rank is not None else int(
+            os.getenv("RANK", "-1"))
+        local_rank = int(os.getenv("LOCAL_RANK", "0"))
+        self.device = torch.device(f"cuda:{local_rank}")
+        if self.rank < 0:
+            raise ValueError("Invalid or unspecified rank.")
+        torch.cuda.set_device(self.device)
+
+        # Initialize the distributed environment.
+        _init_distributed_environment(self.parallel_config, self.rank,
+                                      self.distributed_init_method)
+
+        # Initialize the model.
+        set_random_seed(self.model_config.seed)
+        self.model = get_model(self.model_config)
+
     @torch.inference_mode()
     def profile_num_available_blocks(
         self,
@ -73,8 +82,8 @@ class Worker:
         # number of tokens equal to max_num_batched_tokens.

         # Enable top-k sampling to reflect the accurate memory usage.
-        sampling_params = SamplingParams(top_p=0.99,
-                                         top_k=self.model.config.vocab_size - 1)
+        vocab_size = self.model.config.vocab_size
+        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
         max_num_seqs = self.scheduler_config.max_num_seqs
         seqs = []
@ -91,7 +100,8 @@ class Worker:
             )
             seqs.append(seq)

-        input_tokens, input_positions, input_metadata = self._prepare_inputs(seqs)
+        input_tokens, input_positions, input_metadata = self._prepare_inputs(
+            seqs)

         # Execute the model.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
@ -110,9 +120,12 @@ class Worker:
         total_gpu_memory = get_gpu_memory()
         cache_block_size = CacheEngine.get_cache_block_size(
             block_size, self.model_config, self.parallel_config)
-        num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
-                              - peak_memory) // cache_block_size)
+        num_gpu_blocks = int(
+            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
+            cache_block_size)
         num_cpu_blocks = int(cpu_swap_space // cache_block_size)
+        num_gpu_blocks = max(num_gpu_blocks, 0)
+        num_cpu_blocks = max(num_cpu_blocks, 0)
         torch.cuda.empty_cache()

         # Reset the seed to ensure that the random state is not affected by
@ -123,8 +136,8 @@ class Worker:
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.block_size = cache_config.block_size
-        self.cache_engine = CacheEngine(
-            self.cache_config, self.model_config, self.parallel_config)
+        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
+                                        self.parallel_config)
         self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache

@ -200,8 +213,8 @@ class Worker:
                 generation_block_tables.append(block_table)

                 max_context_len = max(max_context_len, context_len)
-                max_num_blocks_per_seq = max(
-                    max_num_blocks_per_seq, len(block_table))
+                max_num_blocks_per_seq = max(max_num_blocks_per_seq,
+                                             len(block_table))
                 context_lens.append(context_len)

                 block_number = block_table[position // self.block_size]
@ -221,7 +234,8 @@ class Worker:
         context_lens_tensor = torch.cuda.IntTensor(context_lens)
         padded_block_tables = [
             _pad_to_max(block_table, max_num_blocks_per_seq)
-            for block_table in generation_block_tables]
+            for block_table in generation_block_tables
+        ]
         block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)

         seq_data: Dict[int, SequenceData] = {}
@ -289,15 +303,28 @@ class Worker:
 def _init_distributed_environment(
     parallel_config: ParallelConfig,
     rank: int,
-    distributed_init_method: str,
+    distributed_init_method: Optional[str] = None,
 ) -> None:
     """Initialize the distributed environment."""
-    torch.distributed.init_process_group(
-        backend="nccl",
-        world_size=parallel_config.world_size,
-        rank=rank,
-        init_method=distributed_init_method,
-    )
+    if torch.distributed.is_initialized():
+        torch_world_size = torch.distributed.get_world_size()
+        if torch_world_size != parallel_config.world_size:
+            raise RuntimeError(
+                "torch.distributed is already initialized but the torch world "
+                "size does not match parallel_config.world_size "
+                f"({torch_world_size} vs. {parallel_config.world_size}).")
+    elif not distributed_init_method:
+        raise ValueError(
+            "distributed_init_method must be set if torch.distributed "
+            "is not already initialized")
+    else:
+        torch.distributed.init_process_group(
+            backend="nccl",
+            world_size=parallel_config.world_size,
+            rank=rank,
+            init_method=distributed_init_method,
+        )
+
     # A small all_reduce for warmup.
     torch.distributed.all_reduce(torch.zeros(1).cuda())
     initialize_model_parallel(parallel_config.tensor_parallel_size,