mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-28 12:24:34 +08:00
Compare commits
129 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 852ef5b4f5 | |||
| db09d4ad83 | |||
| c957c741d9 | |||
| c07ece5ca4 | |||
| 7a9c20c715 | |||
| 005ba458b5 | |||
| 320a622ec4 | |||
| c9927c1a6a | |||
| fbd80ad409 | |||
| 22379d5513 | |||
| 1696725879 | |||
| 002800f081 | |||
| e15932bb60 | |||
| ce741ba3e4 | |||
| bf87484efa | |||
| 8ce9c50d40 | |||
| 32b6816e55 | |||
| c128d69856 | |||
| 55b28b1eee | |||
| e11222333f | |||
| 28873a2799 | |||
| 0080d8329d | |||
| 0d93f15694 | |||
| becd7a56f1 | |||
| 75471386de | |||
| d2b2eed67c | |||
| 4b6f069b6f | |||
| 791d79de32 | |||
| 94d2f59895 | |||
| 75c0ca9d43 | |||
| 2a4ec90854 | |||
| 85ebcda94d | |||
| d64bf1646c | |||
| a41c20435e | |||
| eedac9dba0 | |||
| 14f9c72bfd | |||
| ad5f2fe34c | |||
| 4f8584756d | |||
| 65fc1c3127 | |||
| c393af6cd7 | |||
| 0c04ce3234 | |||
| 73b3de79ea | |||
| d1744376ae | |||
| 805de738f6 | |||
| 1b151ed181 | |||
| e06f504a76 | |||
| 462ae5220a | |||
| 66c54aa9c3 | |||
| 735ecfff61 | |||
| a57d13cc96 | |||
| 79af7e96a0 | |||
| 621980bdc0 | |||
| aa84c92ef6 | |||
| f7389f4763 | |||
| 55fe8a81ec | |||
| e8ddc08ec8 | |||
| 1b0bd0fe8a | |||
| 20044cab7a | |||
| 64f23c2900 | |||
| d4c7755ca8 | |||
| aa39e42c5a | |||
| 953f28cf9a | |||
| c0d00f5be6 | |||
| 58a072be15 | |||
| 82ad323dee | |||
| df5dd3c68e | |||
| 2d867b55fa | |||
| d7a1c6d614 | |||
| 7d5a155e4a | |||
| 1dde34e0f8 | |||
| 6fc2a38b11 | |||
| c487a221ee | |||
| 9925c17940 | |||
| 8c4b2592fb | |||
| cf21a9bd5c | |||
| 16c3e295a8 | |||
| bda41c70dd | |||
| 453bafb96f | |||
| 328d231c17 | |||
| b4b195b360 | |||
| 20b0d88d16 | |||
| 2bdea7ac11 | |||
| 58df2883cb | |||
| 6d7d95a70a | |||
| 96853af5a8 | |||
| dbed69058c | |||
| 7b6ae94059 | |||
| c6dfc3cdbe | |||
| 51be365143 | |||
| c894836108 | |||
| 75beba29b5 | |||
| ddfdf470ae | |||
| b6fbb9a565 | |||
| 2179e4f4c5 | |||
| a945fcc2ae | |||
| be54f8e5c4 | |||
| b396cb4998 | |||
| 1c395b4eaa | |||
| 3d64cf019e | |||
| 98fe8cb542 | |||
| ffa6d2f9f9 | |||
| 404422f42e | |||
| 7717d0838b | |||
| 42e0c1df78 | |||
| e41f06702c | |||
| d6fa1be3a8 | |||
| 0ffded812a | |||
| 0bd2a573a5 | |||
| 49b26e2cec | |||
| dafd924c1f | |||
| 598dc4b79a | |||
| 85de093472 | |||
| f72297562f | |||
| 9d27b09d12 | |||
| 998d9d1509 | |||
| 425040d4c1 | |||
| 4338cc4750 | |||
| bdd6b4c8bc | |||
| 2b7d3aca2e | |||
| 4026a049d3 | |||
| 43710e8d09 | |||
| 526df28fb2 | |||
| 2cf1a333b6 | |||
| 0b7db411b5 | |||
| 471a7a4566 | |||
| 6214dd6ce9 | |||
| 0603379863 | |||
| 665c48963b | |||
| 298695b766 |
101
.github/workflows/publish.yml
vendored
Normal file
101
.github/workflows/publish.yml
vendored
Normal file
@ -0,0 +1,101 @@
|
||||
# This workflow will upload a Python Package to Release asset
|
||||
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: Create Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- v*
|
||||
|
||||
# Needed to create release and upload assets
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
# Retrieve tag and create release
|
||||
name: Create Release
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
upload_url: ${{ steps.create_release.outputs.upload_url }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Extract branch info
|
||||
shell: bash
|
||||
run: |
|
||||
echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
|
||||
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: "actions/github-script@v6"
|
||||
env:
|
||||
RELEASE_TAG: ${{ env.release_tag }}
|
||||
with:
|
||||
github-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
script: |
|
||||
const script = require('.github/workflows/scripts/create_release.js')
|
||||
await script(github, context, core)
|
||||
|
||||
wheel:
|
||||
name: Build Wheel
|
||||
runs-on: ${{ matrix.os }}
|
||||
needs: release
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: ['ubuntu-20.04']
|
||||
python-version: ['3.8', '3.9', '3.10', '3.11']
|
||||
cuda-version: ['11.8'] # Github runner can't build anything older than 11.8
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Linux Env
|
||||
if: ${{ runner.os == 'Linux' }}
|
||||
run: |
|
||||
bash -x .github/workflows/scripts/env.sh
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install CUDA ${{ matrix.cuda-version }}
|
||||
run: |
|
||||
bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}
|
||||
|
||||
- name: Install PyTorch-cu${{ matrix.cuda-version }}
|
||||
run: |
|
||||
bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
|
||||
|
||||
- name: Build wheel
|
||||
shell: bash
|
||||
run: |
|
||||
bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
|
||||
wheel_name=$(ls dist/*whl | xargs -n 1 basename)
|
||||
asset_name=${wheel_name//"linux"/"manylinux1"}
|
||||
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
|
||||
echo "asset_name=${asset_name}" >> $GITHUB_ENV
|
||||
|
||||
- name: Upload Release Asset
|
||||
uses: actions/upload-release-asset@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
upload_url: ${{ needs.release.outputs.upload_url }}
|
||||
asset_path: ./dist/${{ env.wheel_name }}
|
||||
asset_name: ${{ env.asset_name }}
|
||||
asset_content_type: application/*
|
||||
|
||||
# (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
|
||||
# - name: Publish package
|
||||
# uses: pypa/gh-action-pypi-publish@release/v1.8
|
||||
# with:
|
||||
# repository-url: https://test.pypi.org/legacy/
|
||||
# password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
# skip-existing: true
|
||||
31
.github/workflows/pylint.yml
vendored
Normal file
31
.github/workflows/pylint.yml
vendored
Normal file
@ -0,0 +1,31 @@
|
||||
name: pylint
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
pylint:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pylint==2.8.2
|
||||
- name: Analysing the code with pylint
|
||||
run: |
|
||||
pylint vllm
|
||||
15
.github/workflows/scripts/build.sh
vendored
Normal file
15
.github/workflows/scripts/build.sh
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
#!/bin/bash
|
||||
|
||||
python_executable=python$1
|
||||
cuda_home=/usr/local/cuda-$2
|
||||
|
||||
# Update paths
|
||||
PATH=${cuda_home}/bin:$PATH
|
||||
LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
|
||||
|
||||
# Install requirements
|
||||
$python_executable -m pip install wheel packaging
|
||||
$python_executable -m pip install -r requirements.txt
|
||||
|
||||
# Build
|
||||
$python_executable setup.py bdist_wheel --dist-dir=dist
|
||||
20
.github/workflows/scripts/create_release.js
vendored
Normal file
20
.github/workflows/scripts/create_release.js
vendored
Normal file
@ -0,0 +1,20 @@
|
||||
// Uses Github's API to create the release and wait for result.
|
||||
// We use a JS script since github CLI doesn't provide a way to wait for the release's creation and returns immediately.
|
||||
|
||||
module.exports = async (github, context, core) => {
|
||||
try {
|
||||
const response = await github.rest.repos.createRelease({
|
||||
draft: false,
|
||||
generate_release_notes: true,
|
||||
name: process.env.RELEASE_TAG,
|
||||
owner: context.repo.owner,
|
||||
prerelease: false,
|
||||
repo: context.repo.repo,
|
||||
tag_name: process.env.RELEASE_TAG,
|
||||
});
|
||||
|
||||
core.setOutput('upload_url', response.data.upload_url);
|
||||
} catch (error) {
|
||||
core.setFailed(error.message);
|
||||
}
|
||||
}
|
||||
18
.github/workflows/scripts/cuda-install.sh
vendored
Normal file
18
.github/workflows/scripts/cuda-install.sh
vendored
Normal file
@ -0,0 +1,18 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Replace '.' with '-' ex: 11.8 -> 11-8
|
||||
cuda_version=$(echo $1 | tr "." "-")
|
||||
# Removes '-' and '.' ex: ubuntu-20.04 -> ubuntu2004
|
||||
OS=$(echo $2 | tr -d ".\-")
|
||||
|
||||
# Installs CUDA
|
||||
wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
||||
rm cuda-keyring_1.1-1_all.deb
|
||||
sudo apt -qq update
|
||||
sudo apt -y install cuda-${cuda_version} cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version}
|
||||
sudo apt clean
|
||||
|
||||
# Test nvcc
|
||||
PATH=/usr/local/cuda-$1/bin:${PATH}
|
||||
nvcc --version
|
||||
56
.github/workflows/scripts/env.sh
vendored
Normal file
56
.github/workflows/scripts/env.sh
vendored
Normal file
@ -0,0 +1,56 @@
|
||||
#!/bin/bash
|
||||
|
||||
# This file installs common linux environment tools
|
||||
|
||||
export LANG C.UTF-8
|
||||
|
||||
# python_version=$1
|
||||
|
||||
sudo apt-get update && \
|
||||
sudo apt-get install -y --no-install-recommends \
|
||||
software-properties-common \
|
||||
|
||||
sudo apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
apt-utils \
|
||||
ca-certificates \
|
||||
wget \
|
||||
git \
|
||||
vim \
|
||||
libssl-dev \
|
||||
curl \
|
||||
unzip \
|
||||
unrar \
|
||||
cmake \
|
||||
net-tools \
|
||||
sudo \
|
||||
autotools-dev \
|
||||
rsync \
|
||||
jq \
|
||||
openssh-server \
|
||||
tmux \
|
||||
screen \
|
||||
htop \
|
||||
pdsh \
|
||||
openssh-client \
|
||||
lshw \
|
||||
dmidecode \
|
||||
util-linux \
|
||||
automake \
|
||||
autoconf \
|
||||
libtool \
|
||||
net-tools \
|
||||
pciutils \
|
||||
libpci-dev \
|
||||
libaio-dev \
|
||||
libcap2 \
|
||||
libtinfo5 \
|
||||
fakeroot \
|
||||
devscripts \
|
||||
debhelper \
|
||||
nfs-common
|
||||
|
||||
# Remove github bloat files to free up disk space
|
||||
sudo rm -rf "/usr/local/share/boost"
|
||||
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
|
||||
sudo rm -rf "/usr/share/dotnet"
|
||||
14
.github/workflows/scripts/pytorch-install.sh
vendored
Normal file
14
.github/workflows/scripts/pytorch-install.sh
vendored
Normal file
@ -0,0 +1,14 @@
|
||||
#!/bin/bash
|
||||
|
||||
python_executable=python$1
|
||||
cuda_version=$2
|
||||
|
||||
# Install torch
|
||||
$python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya
|
||||
$python_executable -m pip install torch -f https://download.pytorch.org/whl/cu${cuda_version//./}/torch_stable.html
|
||||
|
||||
# Print version information
|
||||
$python_executable --version
|
||||
$python_executable -c "import torch; print('PyTorch:', torch.__version__)"
|
||||
$python_executable -c "import torch; print('CUDA:', torch.version.cuda)"
|
||||
$python_executable -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
|
||||
31
.github/workflows/yapf.yml
vendored
Normal file
31
.github/workflows/yapf.yml
vendored
Normal file
@ -0,0 +1,31 @@
|
||||
name: yapf
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
jobs:
|
||||
yapf:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install yapf==0.32.0
|
||||
pip install toml==0.10.2
|
||||
- name: Running yapf
|
||||
run: |
|
||||
yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**'
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -170,3 +170,6 @@ cython_debug/
|
||||
|
||||
# Python pickle files
|
||||
*.pkl
|
||||
|
||||
# Sphinx documentation
|
||||
_build/
|
||||
|
||||
434
.pylintrc
Normal file
434
.pylintrc
Normal file
@ -0,0 +1,434 @@
|
||||
# This Pylint rcfile contains a best-effort configuration to uphold the
|
||||
# best-practices and style described in the Google Python style guide:
|
||||
# https://google.github.io/styleguide/pyguide.html
|
||||
#
|
||||
# Its canonical open-source location is:
|
||||
# https://google.github.io/styleguide/pylintrc
|
||||
|
||||
[MASTER]
|
||||
|
||||
# Files or directories to be skipped. They should be base names, not paths.
|
||||
ignore=docs,parallel_utils
|
||||
|
||||
# Files or directories matching the regex patterns are skipped. The regex
|
||||
# matches against base names, not paths.
|
||||
ignore-patterns=
|
||||
|
||||
# Pickle collected data for later comparisons.
|
||||
persistent=no
|
||||
|
||||
# List of plugins (as comma separated values of python modules names) to load,
|
||||
# usually to register additional checkers.
|
||||
load-plugins=
|
||||
|
||||
# Use multiple processes to speed up Pylint.
|
||||
jobs=4
|
||||
|
||||
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
||||
# active Python interpreter and may run arbitrary code.
|
||||
unsafe-load-any-extension=no
|
||||
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
|
||||
# Only show warnings with the listed confidence levels. Leave empty to show
|
||||
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
|
||||
confidence=
|
||||
|
||||
# Enable the message, report, category or checker with the given id(s). You can
|
||||
# either give multiple identifier separated by comma (,) or put this option
|
||||
# multiple time (only on the command line, not in the configuration file where
|
||||
# it should appear only once). See also the "--disable" option for examples.
|
||||
#enable=
|
||||
|
||||
# Disable the message, report, category or checker with the given id(s). You
|
||||
# can either give multiple identifiers separated by comma (,) or put this
|
||||
# option multiple times (only on the command line, not in the configuration
|
||||
# file where it should appear only once).You can also use "--disable=all" to
|
||||
# disable everything first and then reenable specific checks. For example, if
|
||||
# you want to run only the similarities checker, you can use "--disable=all
|
||||
# --enable=similarities". If you want to run only the classes checker, but have
|
||||
# no Warning level messages displayed, use"--disable=all --enable=classes
|
||||
# --disable=W"
|
||||
disable=abstract-method,
|
||||
apply-builtin,
|
||||
arguments-differ,
|
||||
attribute-defined-outside-init,
|
||||
backtick,
|
||||
bad-option-value,
|
||||
basestring-builtin,
|
||||
buffer-builtin,
|
||||
c-extension-no-member,
|
||||
consider-using-enumerate,
|
||||
cmp-builtin,
|
||||
cmp-method,
|
||||
coerce-builtin,
|
||||
coerce-method,
|
||||
delslice-method,
|
||||
div-method,
|
||||
duplicate-code,
|
||||
eq-without-hash,
|
||||
execfile-builtin,
|
||||
file-builtin,
|
||||
filter-builtin-not-iterating,
|
||||
fixme,
|
||||
getslice-method,
|
||||
global-statement,
|
||||
hex-method,
|
||||
idiv-method,
|
||||
implicit-str-concat-in-sequence,
|
||||
import-error,
|
||||
import-self,
|
||||
import-star-module-level,
|
||||
inconsistent-return-statements,
|
||||
input-builtin,
|
||||
intern-builtin,
|
||||
invalid-str-codec,
|
||||
locally-disabled,
|
||||
logging-fstring-interpolation, # added by vLLM
|
||||
logging-not-lazy, # added by vLLM
|
||||
long-builtin,
|
||||
long-suffix,
|
||||
map-builtin-not-iterating,
|
||||
misplaced-comparison-constant,
|
||||
missing-class-docstring, # TODO (vLLM): enable
|
||||
missing-function-docstring,
|
||||
missing-module-docstring, # TODO (vLLM): enable
|
||||
metaclass-assignment,
|
||||
next-method-called,
|
||||
next-method-defined,
|
||||
no-absolute-import,
|
||||
no-else-break,
|
||||
no-else-continue,
|
||||
no-else-raise,
|
||||
no-else-return,
|
||||
no-init, # added
|
||||
no-member,
|
||||
no-name-in-module,
|
||||
no-self-use,
|
||||
nonzero-method,
|
||||
oct-method,
|
||||
old-division,
|
||||
old-ne-operator,
|
||||
old-octal-literal,
|
||||
old-raise-syntax,
|
||||
parameter-unpacking,
|
||||
print-statement,
|
||||
raising-string,
|
||||
range-builtin-not-iterating,
|
||||
raw_input-builtin,
|
||||
rdiv-method,
|
||||
reduce-builtin,
|
||||
relative-import,
|
||||
reload-builtin,
|
||||
round-builtin,
|
||||
setslice-method,
|
||||
signature-differs,
|
||||
standarderror-builtin,
|
||||
suppressed-message,
|
||||
sys-max-int,
|
||||
too-few-public-methods,
|
||||
too-many-ancestors,
|
||||
too-many-arguments,
|
||||
too-many-boolean-expressions,
|
||||
too-many-branches,
|
||||
too-many-instance-attributes,
|
||||
too-many-locals,
|
||||
too-many-nested-blocks,
|
||||
too-many-public-methods,
|
||||
too-many-return-statements,
|
||||
too-many-statements,
|
||||
trailing-newlines,
|
||||
unichr-builtin,
|
||||
unicode-builtin,
|
||||
unnecessary-pass,
|
||||
unpacking-in-except,
|
||||
unspecified-encoding,
|
||||
useless-else-on-loop,
|
||||
useless-object-inheritance,
|
||||
useless-suppression,
|
||||
using-cmp-argument,
|
||||
wrong-import-order,
|
||||
xrange-builtin,
|
||||
zip-builtin-not-iterating,
|
||||
|
||||
|
||||
[REPORTS]
|
||||
|
||||
# Set the output format. Available formats are text, parseable, colorized, msvs
|
||||
# (visual studio) and html. You can also give a reporter class, eg
|
||||
# mypackage.mymodule.MyReporterClass.
|
||||
output-format=text
|
||||
|
||||
# Tells whether to display a full report or only the messages
|
||||
reports=no
|
||||
|
||||
# Python expression which should return a note less than 10 (10 is the highest
|
||||
# note). You have access to the variables errors warning, statement which
|
||||
# respectively contain the number of errors / warnings messages and the total
|
||||
# number of statements analyzed. This is used by the global evaluation report
|
||||
# (RP0004).
|
||||
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
|
||||
|
||||
# Template used to display messages. This is a python new-style format string
|
||||
# used to format the message information. See doc for all details
|
||||
#msg-template=
|
||||
|
||||
|
||||
[BASIC]
|
||||
|
||||
# Good variable names which should always be accepted, separated by a comma
|
||||
good-names=main,_
|
||||
|
||||
# Bad variable names which should always be refused, separated by a comma
|
||||
bad-names=
|
||||
|
||||
# Colon-delimited sets of names that determine each other's naming style when
|
||||
# the name regexes allow several styles.
|
||||
name-group=
|
||||
|
||||
# Include a hint for the correct naming format with invalid-name
|
||||
include-naming-hint=no
|
||||
|
||||
# List of decorators that produce properties, such as abc.abstractproperty. Add
|
||||
# to this list to register other decorators that produce valid properties.
|
||||
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
|
||||
|
||||
# Regular expression matching correct function names
|
||||
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
|
||||
|
||||
# Regular expression matching correct variable names
|
||||
variable-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct constant names
|
||||
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
|
||||
|
||||
# Regular expression matching correct attribute names
|
||||
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct argument names
|
||||
argument-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct class attribute names
|
||||
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
|
||||
|
||||
# Regular expression matching correct inline iteration names
|
||||
inlinevar-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct class names
|
||||
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
|
||||
|
||||
# Regular expression matching correct module names
|
||||
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
|
||||
|
||||
# Regular expression matching correct method names
|
||||
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
|
||||
|
||||
# Regular expression which should only match function or class names that do
|
||||
# not require a docstring.
|
||||
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
|
||||
|
||||
# Minimum line length for functions/classes that require docstrings, shorter
|
||||
# ones are exempt.
|
||||
docstring-min-length=10
|
||||
|
||||
|
||||
[TYPECHECK]
|
||||
|
||||
# List of decorators that produce context managers, such as
|
||||
# contextlib.contextmanager. Add to this list to register other decorators that
|
||||
# produce valid context managers.
|
||||
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
|
||||
|
||||
# Tells whether missing members accessed in mixin class should be ignored. A
|
||||
# mixin class is detected if its name ends with "mixin" (case insensitive).
|
||||
ignore-mixin-members=yes
|
||||
|
||||
# List of module names for which member attributes should not be checked
|
||||
# (useful for modules/projects where namespaces are manipulated during runtime
|
||||
# and thus existing member attributes cannot be deduced by static analysis. It
|
||||
# supports qualified module names, as well as Unix pattern matching.
|
||||
ignored-modules=
|
||||
|
||||
# List of class names for which member attributes should not be checked (useful
|
||||
# for classes with dynamically set attributes). This supports the use of
|
||||
# qualified names.
|
||||
ignored-classes=optparse.Values,thread._local,_thread._local
|
||||
|
||||
# List of members which are set dynamically and missed by pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed. Python regular
|
||||
# expressions are accepted.
|
||||
generated-members=
|
||||
|
||||
|
||||
[FORMAT]
|
||||
|
||||
# Maximum number of characters on a single line.
|
||||
max-line-length=80
|
||||
|
||||
# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
|
||||
# lines made too long by directives to pytype.
|
||||
|
||||
# Regexp for a line that is allowed to be longer than the limit.
|
||||
ignore-long-lines=(?x)(
|
||||
^\s*(\#\ )?<?https?://\S+>?$|
|
||||
^\s*(from\s+\S+\s+)?import\s+.+$)
|
||||
|
||||
# Allow the body of an if to be on the same line as the test if there is no
|
||||
# else.
|
||||
single-line-if-stmt=yes
|
||||
|
||||
# Maximum number of lines in a module
|
||||
max-module-lines=99999
|
||||
|
||||
# String used as indentation unit. The internal Google style guide mandates 2
|
||||
# spaces. Google's externaly-published style guide says 4, consistent with
|
||||
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
|
||||
# projects (like TensorFlow).
|
||||
indent-string=' '
|
||||
|
||||
# Number of spaces of indent required inside a hanging or continued line.
|
||||
indent-after-paren=4
|
||||
|
||||
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
|
||||
expected-line-ending-format=
|
||||
|
||||
|
||||
[MISCELLANEOUS]
|
||||
|
||||
# List of note tags to take in consideration, separated by a comma.
|
||||
notes=TODO
|
||||
|
||||
|
||||
[STRING]
|
||||
|
||||
# This flag controls whether inconsistent-quotes generates a warning when the
|
||||
# character used as a quote delimiter is used inconsistently within a module.
|
||||
check-quote-consistency=yes
|
||||
|
||||
|
||||
[VARIABLES]
|
||||
|
||||
# Tells whether we should check for unused import in __init__ files.
|
||||
init-import=no
|
||||
|
||||
# A regular expression matching the name of dummy variables (i.e. expectedly
|
||||
# not used).
|
||||
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
|
||||
|
||||
# List of additional names supposed to be defined in builtins. Remember that
|
||||
# you should avoid to define new builtins when possible.
|
||||
additional-builtins=
|
||||
|
||||
# List of strings which can identify a callback function by name. A callback
|
||||
# name must start or end with one of those strings.
|
||||
callbacks=cb_,_cb
|
||||
|
||||
# List of qualified module names which can have objects that can redefine
|
||||
# builtins.
|
||||
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
|
||||
|
||||
|
||||
[LOGGING]
|
||||
|
||||
# Logging modules to check that the string format arguments are in logging
|
||||
# function parameter format
|
||||
logging-modules=logging,absl.logging,tensorflow.io.logging
|
||||
|
||||
|
||||
[SIMILARITIES]
|
||||
|
||||
# Minimum lines number of a similarity.
|
||||
min-similarity-lines=4
|
||||
|
||||
# Ignore comments when computing similarities.
|
||||
ignore-comments=yes
|
||||
|
||||
# Ignore docstrings when computing similarities.
|
||||
ignore-docstrings=yes
|
||||
|
||||
# Ignore imports when computing similarities.
|
||||
ignore-imports=no
|
||||
|
||||
|
||||
[SPELLING]
|
||||
|
||||
# Spelling dictionary name. Available dictionaries: none. To make it working
|
||||
# install python-enchant package.
|
||||
spelling-dict=
|
||||
|
||||
# List of comma separated words that should not be checked.
|
||||
spelling-ignore-words=
|
||||
|
||||
# A path to a file that contains private dictionary; one word per line.
|
||||
spelling-private-dict-file=
|
||||
|
||||
# Tells whether to store unknown words to indicated private dictionary in
|
||||
# --spelling-private-dict-file option instead of raising a message.
|
||||
spelling-store-unknown-words=no
|
||||
|
||||
|
||||
[IMPORTS]
|
||||
|
||||
# Deprecated modules which should not be used, separated by a comma
|
||||
deprecated-modules=regsub,
|
||||
TERMIOS,
|
||||
Bastion,
|
||||
rexec,
|
||||
sets
|
||||
|
||||
# Create a graph of every (i.e. internal and external) dependencies in the
|
||||
# given file (report RP0402 must not be disabled)
|
||||
import-graph=
|
||||
|
||||
# Create a graph of external dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
ext-import-graph=
|
||||
|
||||
# Create a graph of internal dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
int-import-graph=
|
||||
|
||||
# Force import order to recognize a module as part of the standard
|
||||
# compatibility libraries.
|
||||
known-standard-library=
|
||||
|
||||
# Force import order to recognize a module as part of a third party library.
|
||||
known-third-party=enchant, absl
|
||||
|
||||
# Analyse import fallback blocks. This can be used to support both Python 2 and
|
||||
# 3 compatible code, which means that the block might have code that exists
|
||||
# only in one or another interpreter, leading to false positives when analysed.
|
||||
analyse-fallback-blocks=no
|
||||
|
||||
|
||||
[CLASSES]
|
||||
|
||||
# List of method names used to declare (i.e. assign) instance attributes.
|
||||
defining-attr-methods=__init__,
|
||||
__new__,
|
||||
setUp
|
||||
|
||||
# List of member names, which should be excluded from the protected access
|
||||
# warning.
|
||||
exclude-protected=_asdict,
|
||||
_fields,
|
||||
_replace,
|
||||
_source,
|
||||
_make
|
||||
|
||||
# List of valid names for the first argument in a class method.
|
||||
valid-classmethod-first-arg=cls,
|
||||
class_
|
||||
|
||||
# List of valid names for the first argument in a metaclass class method.
|
||||
valid-metaclass-classmethod-first-arg=mcs
|
||||
|
||||
|
||||
[EXCEPTIONS]
|
||||
|
||||
# Exceptions that will emit a warning when being caught. Defaults to
|
||||
# "Exception"
|
||||
overgeneral-exceptions=StandardError,
|
||||
Exception,
|
||||
BaseException
|
||||
@ -49,12 +49,15 @@ If not, please file a new issue, providing as much relevant information as possi
|
||||
|
||||
In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
|
||||
|
||||
We include a formatting script [`format.sh`](./format.sh) to format the code.
|
||||
|
||||
### Pull Requests
|
||||
|
||||
When submitting a pull request:
|
||||
|
||||
1. Make sure your code has been rebased on top of the latest commit on the main branch.
|
||||
2. Include a detailed description of the changes in the pull request.
|
||||
2. Ensure code is properly formatted by running [`format.sh`](./format.sh).
|
||||
3. Include a detailed description of the changes in the pull request.
|
||||
Explain why you made the changes you did.
|
||||
If your pull request fixes an open issue, please include a reference to it in the description.
|
||||
|
||||
|
||||
21
README.md
21
README.md
@ -17,8 +17,10 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
|
||||
- [2023/06] We officially released vLLM! vLLM has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid April. Check out our [blog post](https://vllm.ai).
|
||||
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
|
||||
- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
|
||||
- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
|
||||
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
|
||||
|
||||
---
|
||||
|
||||
@ -28,7 +30,7 @@ vLLM is fast with:
|
||||
|
||||
- State-of-the-art serving throughput
|
||||
- Efficient management of attention key and value memory with **PagedAttention**
|
||||
- Dynamic batching of incoming requests
|
||||
- Continuous batching of incoming requests
|
||||
- Optimized CUDA kernels
|
||||
|
||||
vLLM is flexible and easy to use with:
|
||||
@ -41,10 +43,19 @@ vLLM is flexible and easy to use with:
|
||||
|
||||
vLLM seamlessly supports many Huggingface models, including the following architectures:
|
||||
|
||||
- Aquila (`BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
|
||||
- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
|
||||
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
|
||||
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
|
||||
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
|
||||
- GPTNeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
|
||||
- LLaMA (`lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
||||
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
|
||||
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
|
||||
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
|
||||
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
|
||||
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
||||
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
||||
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
||||
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
||||
|
||||
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
|
||||
|
||||
|
||||
@ -17,9 +17,11 @@ def main(args: argparse.Namespace):
|
||||
# the engine will automatically process the request in multiple batches.
|
||||
llm = LLM(
|
||||
model=args.model,
|
||||
tokenizer=args.tokenizer,
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
max_num_seqs=args.batch_size,
|
||||
max_num_batched_tokens=args.batch_size * args.input_len,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
@ -63,6 +65,7 @@ if __name__ == '__main__':
|
||||
description='Benchmark the latency of processing a single batch of '
|
||||
'requests till completion.')
|
||||
parser.add_argument('--model', type=str, default='facebook/opt-125m')
|
||||
parser.add_argument('--tokenizer', type=str, default=None)
|
||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
||||
parser.add_argument('--input-len', type=int, default=32)
|
||||
parser.add_argument('--output-len', type=int, default=128)
|
||||
@ -72,5 +75,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--use-beam-search', action='store_true')
|
||||
parser.add_argument('--num-iters', type=int, default=3,
|
||||
help='Number of iterations to run.')
|
||||
parser.add_argument('--trust-remote-code', action='store_true',
|
||||
help='trust remote code from huggingface')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@ -24,20 +24,13 @@ from typing import AsyncGenerator, List, Tuple
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
# (prompt len, output len, latency)
|
||||
REQUEST_LATENCY: List[Tuple[int, int, float]] = []
|
||||
|
||||
|
||||
def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
if config.model_type == "llama":
|
||||
# A workaround for potential protobuf errors.
|
||||
model_name = "hf-internal-testing/llama-tokenizer"
|
||||
return AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
|
||||
def sample_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
@ -184,7 +177,7 @@ def main(args: argparse.Namespace):
|
||||
np.random.seed(args.seed)
|
||||
|
||||
api_url = f"http://{args.host}:{args.port}/generate"
|
||||
tokenizer = get_tokenizer(args.tokenizer)
|
||||
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||
|
||||
benchmark_start_time = time.time()
|
||||
@ -217,7 +210,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--backend", type=str, default="vllm",
|
||||
choices=["vllm", "tgi"])
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=8001)
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--dataset", type=str, required=True,
|
||||
help="Path to the dataset.")
|
||||
parser.add_argument("--tokenizer", type=str, required=True,
|
||||
@ -234,5 +227,7 @@ if __name__ == "__main__":
|
||||
"Otherwise, we use Poisson process to synthesize "
|
||||
"the request arrival times.")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument('--trust-remote-code', action='store_true',
|
||||
help='trust remote code from huggingface')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@ -6,23 +6,11 @@ import time
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM,
|
||||
PreTrainedTokenizerBase)
|
||||
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
if config.model_type == "llama":
|
||||
# A workaround for potential protobuf errors.
|
||||
model_name = "hf-internal-testing/llama-tokenizer"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
# To enable padding in the HF backend.
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
return tokenizer
|
||||
return AutoTokenizer.from_pretrained(model_name)
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
|
||||
def sample_requests(
|
||||
@ -74,15 +62,19 @@ def sample_requests(
|
||||
def run_vllm(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
model: str,
|
||||
tokenizer: str,
|
||||
tensor_parallel_size: int,
|
||||
seed: int,
|
||||
n: int,
|
||||
use_beam_search: bool,
|
||||
trust_remote_code: bool,
|
||||
) -> float:
|
||||
llm = LLM(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
seed=seed,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
# Add the requests to the engine.
|
||||
@ -116,11 +108,14 @@ def run_hf(
|
||||
n: int,
|
||||
use_beam_search: bool,
|
||||
max_batch_size: int,
|
||||
trust_remote_code: bool,
|
||||
) -> float:
|
||||
assert not use_beam_search
|
||||
tokenizer = get_tokenizer(model)
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
model, torch_dtype=torch.float16)
|
||||
llm = AutoModelForCausalLM.from_pretrained(model,
|
||||
torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
||||
if llm.config.model_type == "llama":
|
||||
# To enable padding in the HF backend.
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
llm = llm.cuda()
|
||||
|
||||
pbar = tqdm(total=len(requests))
|
||||
@ -170,17 +165,18 @@ def main(args: argparse.Namespace):
|
||||
random.seed(args.seed)
|
||||
|
||||
# Sample the requests.
|
||||
tokenizer = get_tokenizer(args.model)
|
||||
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||
|
||||
if args.backend == "vllm":
|
||||
elapsed_time = run_vllm(
|
||||
requests, args.model, args.tensor_parallel_size, args.seed, args.n,
|
||||
args.use_beam_search)
|
||||
requests, args.model, args.tokenizer, args.tensor_parallel_size,
|
||||
args.seed, args.n, args.use_beam_search, args.trust_remote_code)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
args.use_beam_search, args.hf_max_batch_size)
|
||||
elapsed_time = run_hf(
|
||||
requests, args.model, tokenizer, args.n, args.use_beam_search,
|
||||
args.hf_max_batch_size, args.trust_remote_code)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
total_num_tokens = sum(
|
||||
@ -198,6 +194,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--dataset", type=str, required=True,
|
||||
help="Path to the dataset.")
|
||||
parser.add_argument("--model", type=str, default="facebook/opt-125m")
|
||||
parser.add_argument("--tokenizer", type=str, default=None)
|
||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
||||
parser.add_argument("--n", type=int, default=1,
|
||||
help="Number of generated sequences per prompt.")
|
||||
@ -207,12 +204,18 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--hf-max-batch-size", type=int, default=None,
|
||||
help="Maximum batch size for HF backend.")
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
help='trust remote code from huggingface')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.backend == "vllm":
|
||||
if args.hf_max_batch_size is not None:
|
||||
raise ValueError("HF max batch size is only for HF backend.")
|
||||
elif args.backend == "hf":
|
||||
if args.hf_max_batch_size is None:
|
||||
raise ValueError("HF max batch size is required for HF backend.")
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
|
||||
main(args)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
PORT=8001
|
||||
PORT=8000
|
||||
MODEL=$1
|
||||
TOKENS=$2
|
||||
|
||||
|
||||
@ -4,9 +4,25 @@ void silu_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_new(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_fast(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"silu_and_mul",
|
||||
&silu_and_mul,
|
||||
"Activation function used in SwiGLU.");
|
||||
m.def(
|
||||
"gelu_new",
|
||||
&gelu_new,
|
||||
"GELU implementation used in GPT-2.");
|
||||
m.def(
|
||||
"gelu_fast",
|
||||
&gelu_fast,
|
||||
"Approximate GELU implementation.");
|
||||
}
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename T>
|
||||
@ -34,9 +36,7 @@ void silu_and_mul(
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(d, 1024));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"silu_and_mul_kernel",
|
||||
[&] {
|
||||
@ -46,3 +46,69 @@ void silu_and_mul(
|
||||
d);
|
||||
});
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Element-wise activation kernel template.
|
||||
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
||||
__global__ void activation_kernel(
|
||||
scalar_t* __restrict__ out, // [num_tokens, d]
|
||||
const scalar_t* __restrict__ input, // [num_tokens, d]
|
||||
const int d) {
|
||||
const int token_idx = blockIdx.x;
|
||||
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = __ldg(&input[token_idx * d + idx]);
|
||||
out[token_idx * d + idx] = ACT_FN(x);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// Launch element-wise activation kernel.
|
||||
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
|
||||
int num_tokens = input.size(0); \
|
||||
int d = input.size(1); \
|
||||
dim3 grid(num_tokens); \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), \
|
||||
"activation_kernel", \
|
||||
[&] { \
|
||||
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<scalar_t>(), \
|
||||
input.data_ptr<scalar_t>(), \
|
||||
d); \
|
||||
});
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
|
||||
const float x3 = (float) (x * x * x);
|
||||
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
|
||||
return ((T) 0.5) * x * (((T) 1.0) + t);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
|
||||
const float f = (float) x;
|
||||
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
|
||||
return ((T) 0.5) * x * (((T) 1.0) + t);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void gelu_new(
|
||||
torch::Tensor& out, // [num_tokens, d]
|
||||
torch::Tensor& input) // [num_tokens, d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
|
||||
}
|
||||
|
||||
void gelu_fast(
|
||||
torch::Tensor& out, // [num_tokens, d]
|
||||
torch::Tensor& input) // [num_tokens, d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
|
||||
}
|
||||
|
||||
@ -1,15 +1,18 @@
|
||||
#include <torch/extension.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
void single_query_cached_kv_attention(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& head_mapping,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
int block_size,
|
||||
int max_context_len);
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
|
||||
@ -74,14 +74,20 @@ template<
|
||||
__global__ void single_query_cached_kv_attention_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int* __restrict__ head_mapping, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const int q_stride) {
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride) {
|
||||
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
|
||||
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
|
||||
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int thread_idx = threadIdx.x;
|
||||
@ -90,7 +96,9 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
|
||||
const int head_idx = blockIdx.x;
|
||||
const int num_heads = gridDim.x;
|
||||
const int kv_head_idx = head_mapping[head_idx];
|
||||
const int seq_idx = blockIdx.y;
|
||||
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
|
||||
|
||||
// A vector type to store a part of a key or a query.
|
||||
// The vector size is configured in such a way that the threads in a thread group
|
||||
@ -114,12 +122,13 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
// th vectors of the query, and so on.
|
||||
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
|
||||
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
Q_vec q_vecs[NUM_VECS_PER_THREAD];
|
||||
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
|
||||
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
|
||||
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
|
||||
q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
|
||||
q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
|
||||
}
|
||||
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
|
||||
|
||||
// Memory planning.
|
||||
extern __shared__ char shared_mem[];
|
||||
@ -156,8 +165,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
|
||||
const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
|
||||
+ head_idx * HEAD_SIZE * BLOCK_SIZE
|
||||
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
|
||||
+ kv_head_idx * kv_head_stride
|
||||
+ physical_block_offset * x;
|
||||
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
||||
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
||||
@ -167,12 +176,14 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
|
||||
// Compute dot product.
|
||||
// This includes a reduction across the threads in the same thread group.
|
||||
const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
|
||||
const bool mask = token_idx >= context_len;
|
||||
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
|
||||
// Add the ALiBi bias if slopes are given.
|
||||
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
|
||||
|
||||
if (thread_group_offset == 0) {
|
||||
// Store the partial reductions to shared memory.
|
||||
// NOTE(woosuk): It is required to zero out the masked logits.
|
||||
const bool mask = token_idx >= context_len;
|
||||
logits[token_idx] = mask ? 0.f : qk;
|
||||
// Update the max value.
|
||||
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
|
||||
@ -235,6 +246,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
accs[i] = 0.f;
|
||||
}
|
||||
|
||||
scalar_t zero_value;
|
||||
zero(zero_value);
|
||||
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
|
||||
const int physical_block_number = block_table[block_idx];
|
||||
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
|
||||
@ -242,14 +255,24 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
L_vec logits_vec;
|
||||
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
|
||||
|
||||
const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
|
||||
+ head_idx * HEAD_SIZE * BLOCK_SIZE;
|
||||
const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
|
||||
+ kv_head_idx * kv_head_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||
if (row_idx < HEAD_SIZE) {
|
||||
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
||||
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
||||
if (block_idx == num_blocks - 1) {
|
||||
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
|
||||
// we should explicitly zero out the values since they may contain NaNs.
|
||||
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
|
||||
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
|
||||
#pragma unroll
|
||||
for (int j = 0; j <= V_VEC_SIZE; j++) {
|
||||
v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
|
||||
}
|
||||
}
|
||||
accs[i] += dot(logits_vec, v_vec);
|
||||
}
|
||||
}
|
||||
@ -324,11 +347,15 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
query_ptr, \
|
||||
key_cache_ptr, \
|
||||
value_cache_ptr, \
|
||||
head_mapping_ptr, \
|
||||
scale, \
|
||||
block_tables_ptr, \
|
||||
context_lens_ptr, \
|
||||
max_num_blocks_per_seq, \
|
||||
query_stride);
|
||||
alibi_slopes_ptr, \
|
||||
q_stride, \
|
||||
kv_block_stride, \
|
||||
kv_head_stride);
|
||||
|
||||
// TODO(woosuk): Tune NUM_THREADS.
|
||||
template<
|
||||
@ -340,23 +367,33 @@ void single_query_cached_kv_attention_launcher(
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& head_mapping,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
int max_context_len) {
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
int max_num_blocks_per_seq = block_tables.size(1);
|
||||
int query_stride = query.stride(0);
|
||||
int q_stride = query.stride(0);
|
||||
int kv_block_stride = key_cache.stride(0);
|
||||
int kv_head_stride = key_cache.stride(1);
|
||||
|
||||
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
assert(head_size % thread_group_size == 0);
|
||||
|
||||
// NOTE: alibi_slopes is optional.
|
||||
const float* alibi_slopes_ptr = alibi_slopes ?
|
||||
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
|
||||
: nullptr;
|
||||
|
||||
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||
int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||
|
||||
@ -371,7 +408,7 @@ void single_query_cached_kv_attention_launcher(
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
switch (head_size) {
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted head sizes
|
||||
// 32, 160, 192, 256.
|
||||
// 32, 160, 192.
|
||||
// case 32:
|
||||
// LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
|
||||
// break;
|
||||
@ -384,6 +421,9 @@ void single_query_cached_kv_attention_launcher(
|
||||
case 96:
|
||||
LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
|
||||
break;
|
||||
case 112:
|
||||
LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
|
||||
break;
|
||||
@ -393,9 +433,9 @@ void single_query_cached_kv_attention_launcher(
|
||||
// case 192:
|
||||
// LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
|
||||
// break;
|
||||
// case 256:
|
||||
// LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
|
||||
// break;
|
||||
case 256:
|
||||
LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||
break;
|
||||
@ -408,10 +448,12 @@ void single_query_cached_kv_attention_launcher(
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
head_mapping, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
context_lens, \
|
||||
max_context_len);
|
||||
max_context_len, \
|
||||
alibi_slopes);
|
||||
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||
// 1, 2, 4, 64, 128, 256.
|
||||
@ -454,11 +496,13 @@ void single_query_cached_kv_attention(
|
||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& head_mapping, // [num_heads]
|
||||
float scale,
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
int block_size,
|
||||
int max_context_len) {
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
|
||||
@ -420,4 +420,14 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
|
||||
#endif
|
||||
}
|
||||
|
||||
// Zero-out a variable.
|
||||
inline __device__ void zero(__nv_bfloat16& dst) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
// Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
|
||||
dst = __ushort_as_bfloat16((unsigned short)0x0000U);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
@ -390,11 +390,6 @@ inline __device__ float sum(uint4 v) {
|
||||
return sum(c);
|
||||
}
|
||||
|
||||
// Zero-out a vector.
|
||||
inline __device__ void zero(uint16_t& dst) {
|
||||
dst = uint16_t(0);
|
||||
}
|
||||
|
||||
// From float32 to float16.
|
||||
inline __device__ void from_float(uint16_t& dst, float src) {
|
||||
dst = float_to_half(src);
|
||||
@ -441,4 +436,9 @@ inline __device__ Float8_ to_float(uint4 u) {
|
||||
return tmp;
|
||||
}
|
||||
|
||||
// Zero-out a variable.
|
||||
inline __device__ void zero(uint16_t& dst) {
|
||||
dst = uint16_t(0);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
@ -265,4 +265,9 @@ inline __device__ Float8_ to_float(Float8_ u) {
|
||||
return u;
|
||||
}
|
||||
|
||||
// Zero-out a variable.
|
||||
inline __device__ void zero(float& dst) {
|
||||
dst = 0.f;
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <map>
|
||||
@ -125,9 +127,7 @@ void copy_blocks(
|
||||
dim3 grid(num_layers, num_pairs);
|
||||
dim3 block(std::min(1024, numel_per_block));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
||||
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||
@ -202,9 +202,7 @@ void reshape_and_cache(
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(),
|
||||
"reshape_and_cache_kernel",
|
||||
[&] {
|
||||
@ -364,9 +362,7 @@ void gather_cached_kv(
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(),
|
||||
"gather_cached_kv_kernel_optimized",
|
||||
[&] {
|
||||
|
||||
csrc/dispatch_utils.h (new file, 14 lines)
@ -0,0 +1,14 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
|
||||
*/
|
||||
#include <torch/extension.h>
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
@ -1,6 +1,7 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
#include "reduction_utils.cuh"
|
||||
|
||||
namespace vllm {
|
||||
@ -46,9 +47,7 @@ void rms_norm(
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"rms_norm_kernel",
|
||||
[&] {
|
||||
|
||||
@ -1,15 +1,16 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
void rotary_embedding_neox(
|
||||
void rotary_embedding(
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache);
|
||||
torch::Tensor& cos_sin_cache,
|
||||
bool is_neox);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"rotary_embedding_neox",
|
||||
&rotary_embedding_neox,
|
||||
"Apply GPT-NeoX style rotary embedding to query and key");
|
||||
"rotary_embedding",
|
||||
&rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||
}
|
||||
|
||||
@ -1,17 +1,51 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void rotary_embedding_neox_kernel(
|
||||
template<typename scalar_t, bool IS_NEOX>
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
scalar_t* __restrict__ arr,
|
||||
const scalar_t* __restrict__ cos_ptr,
|
||||
const scalar_t* __restrict__ sin_ptr,
|
||||
int rot_offset,
|
||||
int embed_dim)
|
||||
{
|
||||
int x_index, y_index;
|
||||
scalar_t cos, sin;
|
||||
if (IS_NEOX) {
|
||||
// GPT-NeoX style rotary embedding.
|
||||
x_index = rot_offset;
|
||||
y_index = embed_dim + rot_offset;
|
||||
cos = __ldg(cos_ptr + x_index);
|
||||
sin = __ldg(sin_ptr + x_index);
|
||||
} else {
|
||||
// GPT-J style rotary embedding.
|
||||
x_index = 2 * rot_offset;
|
||||
y_index = 2 * rot_offset + 1;
|
||||
cos = __ldg(cos_ptr + x_index / 2);
|
||||
sin = __ldg(sin_ptr + x_index / 2);
|
||||
}
|
||||
|
||||
const scalar_t x = arr[x_index];
|
||||
const scalar_t y = arr[y_index];
|
||||
arr[x_index] = x * cos - y * sin;
|
||||
arr[y_index] = y * cos + x * sin;
|
||||
}
|
||||
|
||||
template<typename scalar_t, bool IS_NEOX>
|
||||
__global__ void rotary_embedding_kernel(
|
||||
const int64_t* __restrict__ positions, // [num_tokens]
|
||||
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int rot_dim,
|
||||
const int stride,
|
||||
const int query_stride,
|
||||
const int key_stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
@ -19,65 +53,75 @@ __global__ void rotary_embedding_neox_kernel(
|
||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
const int embed_dim = rot_dim / 2;
|
||||
const int n = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||
const scalar_t* cos_ptr = cache_ptr;
|
||||
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
||||
|
||||
const int nq = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int token_head = token_idx * stride + head_idx * head_size;
|
||||
|
||||
const int token_head = token_idx * query_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
const int x_index = rot_offset;
|
||||
const int y_index = embed_dim + rot_offset;
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
|
||||
const int out_x = token_idx * stride + head_idx * head_size + x_index;
|
||||
const int out_y = token_idx * stride + head_idx * head_size + y_index;
|
||||
|
||||
const scalar_t cos = __ldg(cache_ptr + x_index);
|
||||
const scalar_t sin = __ldg(cache_ptr + y_index);
|
||||
|
||||
const scalar_t q_x = query[token_head + x_index];
|
||||
const scalar_t q_y = query[token_head + y_index];
|
||||
query[out_x] = q_x * cos - q_y * sin;
|
||||
query[out_y] = q_y * cos + q_x * sin;
|
||||
|
||||
const scalar_t k_x = key[token_head + x_index];
|
||||
const scalar_t k_y = key[token_head + y_index];
|
||||
key[out_x] = k_x * cos - k_y * sin;
|
||||
key[out_y] = k_y * cos + k_x * sin;
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int token_head = token_idx * key_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void rotary_embedding_neox(
|
||||
void rotary_embedding(
|
||||
torch::Tensor& positions, // [num_tokens]
|
||||
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [num_tokens, num_kv_heads * head_size]
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache) // [max_position, rot_dim]
|
||||
{
|
||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
bool is_neox) {
|
||||
int num_tokens = query.size(0);
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(1) / head_size;
|
||||
int stride = query.stride(0);
|
||||
TORCH_CHECK(stride == key.stride(0));
|
||||
int num_kv_heads = key.size(1) / head_size;
|
||||
int query_stride = query.stride(0);
|
||||
int key_stride = key.stride(0);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
query.scalar_type(),
|
||||
"rotary_embedding_neox",
|
||||
"rotary_embedding",
|
||||
[&] {
|
||||
vllm::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
if (is_neox) {
|
||||
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim,
|
||||
stride,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
} else {
|
||||
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
}
|
||||
});
|
||||
}
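For reference, the rotation applied by `rotary_embedding_kernel` can be written compactly on the host. Below is a minimal PyTorch sketch of the same math (an illustration only, not the kernel): it assumes `cos` and `sin` have already been gathered from `cos_sin_cache` for each token's position, each with shape `[num_tokens, rot_dim // 2]`, and that `x` is `[num_tokens, num_heads, head_size]`.

```python
import torch

def ref_rotary_embedding(x: torch.Tensor,    # [num_tokens, num_heads, head_size]
                         cos: torch.Tensor,  # [num_tokens, rot_dim // 2]
                         sin: torch.Tensor,  # [num_tokens, rot_dim // 2]
                         is_neox: bool) -> torch.Tensor:
    # Reference only: mirrors the indexing in rotary_embedding_kernel.
    rot_dim = 2 * cos.shape[-1]
    rot, rest = x[..., :rot_dim], x[..., rot_dim:]  # only rot_dim elems rotate
    if is_neox:
        # GPT-NeoX style: element i pairs with element embed_dim + i.
        x1, x2 = rot.chunk(2, dim=-1)
    else:
        # GPT-J style: element 2i pairs with element 2i + 1.
        x1, x2 = rot[..., ::2], rot[..., 1::2]
    c, s = cos.unsqueeze(-2), sin.unsqueeze(-2)  # broadcast over heads
    o1 = x1 * c - x2 * s
    o2 = x2 * c + x1 * s
    if is_neox:
        rotated = torch.cat((o1, o2), dim=-1)
    else:
        rotated = torch.stack((o1, o2), dim=-1).flatten(-2)  # re-interleave
    return torch.cat((rotated, rest), dim=-1)
```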
@@ -4,14 +4,14 @@

```bash
# Install dependencies.
pip -r requirements-docs.txt
pip install -r requirements-docs.txt

# Build the docs.
make clean
make html
```

## Open the docs with your brower
## Open the docs with your browser

```bash
python -m http.server -d build/html/
@@ -29,7 +29,7 @@ vLLM is fast with:

* State-of-the-art serving throughput
* Efficient management of attention key and value memory with **PagedAttention**
* Dynamic batching of incoming requests
* Continuous batching of incoming requests
* Optimized CUDA kernels

vLLM is flexible and easy to use with:

@@ -40,7 +40,11 @@ vLLM is flexible and easy to use with:

* Streaming outputs
* OpenAI-compatible API server

For more information, please refer to our `blog post <https://vllm.ai>`_.
For more information, check out the following:

* `vLLM announcing blog post <https://vllm.ai>`_ (intro to PagedAttention)
* `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency <https://www.anyscale.com/blog/continuous-batching-llm-inference>`_ by Cade Daniel et al.

Documentation

@@ -53,6 +57,13 @@ Documentation

   getting_started/installation
   getting_started/quickstart

.. toctree::
   :maxdepth: 1
   :caption: Serving

   serving/distributed_serving
   serving/run_on_sky

.. toctree::
   :maxdepth: 1
   :caption: Models
@@ -59,7 +59,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following

+    kv_caches: List[KVCache],
+    input_metadata: InputMetadata,
+    cache_events: Optional[List[torch.cuda.Event]],
+) -> Dict[int, SequenceOutputs]:
+) -> SamplerOutput:

3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.
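Step 3 above is about tensor shapes: instead of a padded `[batch_size, seq_len]` batch, the rewritten `forward` receives one flat stream of token ids with matching positions that restart at 0 for every sequence. The snippet below is a self-contained illustration of that flattening with made-up values; it is not vLLM code.

```python
import torch

# Illustration of step 3 only (made-up values): a padded [batch_size, seq_len]
# batch becomes one flat [num_tokens] stream, and positions restart per sequence.
batched_input_ids = torch.tensor([[11, 12, 13],
                                  [21, 22,  0]])  # second row is padded
seq_lens = [3, 2]

flat_input_ids = torch.cat(
    [batched_input_ids[i, :n] for i, n in enumerate(seq_lens)])
flat_positions = torch.cat([torch.arange(n) for n in seq_lens])

print(flat_input_ids)   # tensor([11, 12, 13, 21, 22])
print(flat_positions)   # tensor([0, 1, 2, 0, 1])
```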
@ -14,18 +14,45 @@ Alongside each architecture, we include some popular models that use it.
|
||||
* - Architecture
|
||||
- Models
|
||||
- Example HuggingFace Models
|
||||
* - :code:`AquilaForCausalLM`
|
||||
- Aquila
|
||||
- :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc.
|
||||
* - :code:`BaiChuanForCausalLM`
|
||||
- Baichuan
|
||||
- :code:`baichuan-inc/Baichuan-7B`, :code:`baichuan-inc/Baichuan-13B-Chat`, etc.
|
||||
* - :code:`BloomForCausalLM`
|
||||
- BLOOM, BLOOMZ, BLOOMChat
|
||||
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
|
||||
* - :code:`FalconForCausalLM`
|
||||
- Falcon
|
||||
- :code:`tiiuae/falcon-7b``, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
|
||||
* - :code:`GPT2LMHeadModel`
|
||||
- GPT-2
|
||||
- :code:`gpt2`, :code:`gpt2-xl`, etc.
|
||||
* - :code:`GPTBigCodeForCausalLM`
|
||||
- StarCoder, SantaCoder, WizardCoder
|
||||
- :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc.
|
||||
* - :code:`GPTJForCausalLM`
|
||||
- GPT-J
|
||||
- :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc.
|
||||
* - :code:`GPTNeoXForCausalLM`
|
||||
- GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
|
||||
- :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
|
||||
* - :code:`InternLMForCausalLM`
|
||||
- InternLM
|
||||
- :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc.
|
||||
* - :code:`LlamaForCausalLM`
|
||||
- LLaMA, Vicuna, Alpaca, Koala, Guanaco
|
||||
- :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc.
|
||||
- LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
|
||||
- :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, etc.
|
||||
* - :code:`MPTForCausalLM`
|
||||
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
|
||||
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
|
||||
* - :code:`OPTForCausalLM`
|
||||
- OPT, OPT-IML
|
||||
- :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
|
||||
* - :code:`QWenLMHeadModel`
|
||||
- Qwen
|
||||
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
|
||||
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
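As a quick end-to-end check for any architecture in the table, something like the following minimal sketch should suffice (`facebook/opt-125m` is just a small example checkpoint, the same one used by the tests in this diff):

```python
from vllm import LLM, SamplingParams

# Minimal sketch: swap in any model name from the table above.
llm = LLM(model="facebook/opt-125m")
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(["Hello, my name is"], sampling_params)
for output in outputs:
    print(output.prompt, output.outputs[0].text)
```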
docs/source/serving/distributed_serving.rst (new file, 38 lines)
@@ -0,0 +1,38 @@
.. _distributed_serving:

Distributed Inference and Serving
=================================

vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We manage the distributed runtime with `Ray <https://github.com/ray-project/ray>`_. To run distributed inference, install Ray with:

.. code-block:: console

    $ pip install ray

To run multi-GPU inference with the :code:`LLM` class, set the :code:`tensor_parallel_size` argument to the number of GPUs you want to use. For example, to run inference on 4 GPUs:

.. code-block:: python

    from vllm import LLM
    llm = LLM("facebook/opt-13b", tensor_parallel_size=4)
    output = llm.generate("San Francisco is a")

To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument when starting the server. For example, to run the API server on 4 GPUs:

.. code-block:: console

    $ python -m vllm.entrypoints.api_server \
    $     --model facebook/opt-13b \
    $     --tensor-parallel-size 4

To scale vLLM beyond a single machine, start a `Ray runtime <https://docs.ray.io/en/latest/ray-core/starting-ray.html>`_ via CLI before running vLLM:

.. code-block:: console

    $ # On head node
    $ ray start --head

    $ # On worker nodes
    $ ray start --address=<ray-head-address>

After that, you can run inference and serving on multiple machines by launching the vLLM process on the head node and setting :code:`tensor_parallel_size` to the total number of GPUs across all machines.
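For instance, with a Ray cluster spanning two nodes of four GPUs each, a sketch of the corresponding call would be (the model name and cluster size here are only illustrative assumptions):

```python
from vllm import LLM

# Assumes `ray start` has already been run on the head and worker nodes as
# shown above, and that the cluster exposes 8 GPUs in total (2 nodes x 4 GPUs).
llm = LLM("facebook/opt-13b", tensor_parallel_size=8)
output = llm.generate("San Francisco is a")
print(output[0].outputs[0].text)
```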
docs/source/serving/run_on_sky.rst (new file, 69 lines)
@ -0,0 +1,69 @@
|
||||
.. _on_cloud:
|
||||
|
||||
Running on clouds with SkyPilot
|
||||
===============================
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<p align="center">
|
||||
<img src="https://imgur.com/yxtzPEu.png" alt="vLLM"/>
|
||||
</p>
|
||||
|
||||
vLLM can be run on the cloud to scale to multiple GPUs with `SkyPilot <https://github.com/skypilot-org/skypilot>`__, an open-source framework for running LLMs on any cloud.
|
||||
|
||||
To install SkyPilot and set up your cloud credentials, run:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install skypilot
|
||||
$ sky check
|
||||
|
||||
See the vLLM SkyPilot YAML for serving, `serving.yaml <https://github.com/skypilot-org/skypilot/blob/master/llm/vllm/serve.yaml>`__.
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
resources:
|
||||
accelerators: A100
|
||||
|
||||
envs:
|
||||
MODEL_NAME: decapoda-research/llama-13b-hf
|
||||
TOKENIZER: hf-internal-testing/llama-tokenizer
|
||||
|
||||
setup: |
|
||||
conda create -n vllm python=3.9 -y
|
||||
conda activate vllm
|
||||
git clone https://github.com/vllm-project/vllm.git
|
||||
cd vllm
|
||||
pip install .
|
||||
pip install gradio
|
||||
|
||||
run: |
|
||||
conda activate vllm
|
||||
echo 'Starting vllm api server...'
|
||||
python -u -m vllm.entrypoints.api_server \
|
||||
--model $MODEL_NAME \
|
||||
--tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
|
||||
--tokenizer $TOKENIZER 2>&1 | tee api_server.log &
|
||||
echo 'Waiting for vllm api server to start...'
|
||||
while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done
|
||||
echo 'Starting gradio server...'
|
||||
python vllm/examples/gradio_webserver.py
|
||||
|
||||
Start serving the LLaMA-13B model on an A100 GPU:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ sky launch serving.yaml
|
||||
|
||||
Check the output of the command. There will be a shareable gradio link (like the last line of the following). Open it in your browser to use the LLaMA model for text completion.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
(task, pid=7431) Running on public URL: https://<gradio-hash>.gradio.live
|
||||
|
||||
**Optional**: Serve the 65B model instead of the default 13B and use more GPUs:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
sky launch -c vllm-serve-new -s serve.yaml --gpus A100:8 --env MODEL_NAME=decapoda-research/llama-65b-hf
|
||||
|
||||
@ -14,7 +14,9 @@ def clear_line(n: int = 1) -> None:
|
||||
print(LINE_UP, end=LINE_CLEAR, flush=True)
|
||||
|
||||
|
||||
def post_http_request(prompt: str, api_url: str, n: int = 1,
|
||||
def post_http_request(prompt: str,
|
||||
api_url: str,
|
||||
n: int = 1,
|
||||
stream: bool = False) -> requests.Response:
|
||||
headers = {"User-Agent": "Test Client"}
|
||||
pload = {
|
||||
@ -30,7 +32,8 @@ def post_http_request(prompt: str, api_url: str, n: int = 1,
|
||||
|
||||
|
||||
def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
|
||||
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
|
||||
for chunk in response.iter_lines(chunk_size=8192,
|
||||
decode_unicode=False,
|
||||
delimiter=b"\0"):
|
||||
if chunk:
|
||||
data = json.loads(chunk.decode("utf-8"))
|
||||
|
||||
@ -12,9 +12,14 @@ def http_bot(prompt):
|
||||
"stream": True,
|
||||
"max_tokens": 128,
|
||||
}
|
||||
response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
|
||||
response = requests.post(args.model_url,
|
||||
headers=headers,
|
||||
json=pload,
|
||||
stream=True)
|
||||
|
||||
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
|
||||
for chunk in response.iter_lines(chunk_size=8192,
|
||||
decode_unicode=False,
|
||||
delimiter=b"\0"):
|
||||
if chunk:
|
||||
data = json.loads(chunk.decode("utf-8"))
|
||||
output = data["text"][0]
|
||||
@ -23,11 +28,11 @@ def http_bot(prompt):
|
||||
|
||||
def build_demo():
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown(
|
||||
"# vLLM text completion demo\n"
|
||||
)
|
||||
inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
|
||||
outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
|
||||
gr.Markdown("# vLLM text completion demo\n")
|
||||
inputbox = gr.Textbox(label="Input",
|
||||
placeholder="Enter text and press ENTER")
|
||||
outputbox = gr.Textbox(label="Output",
|
||||
placeholder="Generated result from the model")
|
||||
inputbox.submit(http_bot, [inputbox], [outputbox])
|
||||
return demo
|
||||
|
||||
@ -36,7 +41,9 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=8001)
|
||||
parser.add_argument("--model-url", type=str, default="http://localhost:8000/generate")
|
||||
parser.add_argument("--model-url",
|
||||
type=str,
|
||||
default="http://localhost:8000/generate")
|
||||
args = parser.parse_args()
|
||||
|
||||
demo = build_demo()
|
||||
|
||||
@ -10,19 +10,25 @@ def main(args: argparse.Namespace):
|
||||
|
||||
# Test the following prompts.
|
||||
test_prompts = [
|
||||
("A robot may not injure a human being", SamplingParams()),
|
||||
("A robot may not injure a human being",
|
||||
SamplingParams(temperature=0.0)),
|
||||
("To be or not to be,",
|
||||
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
|
||||
("What is the meaning of life?",
|
||||
SamplingParams(n=2, best_of=5, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
|
||||
SamplingParams(n=2,
|
||||
best_of=5,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
frequency_penalty=0.1)),
|
||||
("It is only with the heart that one can see rightly",
|
||||
SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
|
||||
SamplingParams(n=3, best_of=3, use_beam_search=True,
|
||||
temperature=0.0)),
|
||||
]
|
||||
|
||||
# Run the engine by calling `engine.step()` manually.
|
||||
request_id = 0
|
||||
while True:
|
||||
# To test iteration-level scheduling, we add one request at each step.
|
||||
# To test continuous batching, we add one request at each step.
|
||||
if test_prompts:
|
||||
prompt, sampling_params = test_prompts.pop(0)
|
||||
engine.add_request(str(request_id), prompt, sampling_params)
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
|
||||
examples/openai_chatcompletion_client.py (new file, 33 lines)
@ -0,0 +1,33 @@
|
||||
import openai
|
||||
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai.api_key = "EMPTY"
|
||||
openai.api_base = "http://localhost:8000/v1"
|
||||
|
||||
# List models API
|
||||
models = openai.Model.list()
|
||||
print("Models:", models)
|
||||
|
||||
model = models["data"][0]["id"]
|
||||
|
||||
# Chat completion API
|
||||
chat_completion = openai.ChatCompletion.create(
|
||||
model=model,
|
||||
messages=[{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Who won the world series in 2020?"
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"content":
|
||||
"The Los Angeles Dodgers won the World Series in 2020."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Where was it played?"
|
||||
}])
|
||||
|
||||
print("Chat completion results:")
|
||||
print(chat_completion)
|
||||
@ -3,21 +3,26 @@ import openai
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai.api_key = "EMPTY"
|
||||
openai.api_base = "http://localhost:8000/v1"
|
||||
model = "facebook/opt-125m"
|
||||
|
||||
# Test list models API
|
||||
# List models API
|
||||
models = openai.Model.list()
|
||||
print("Models:", models)
|
||||
|
||||
# Test completion API
|
||||
stream = True
|
||||
completion = openai.Completion.create(
|
||||
model=model, prompt="A robot may not injure a human being", echo=False, n=2,
|
||||
best_of=3, stream=stream, logprobs=3)
|
||||
model = models["data"][0]["id"]
|
||||
|
||||
# print the completion
|
||||
# Completion API
|
||||
stream = False
|
||||
completion = openai.Completion.create(
|
||||
model=model,
|
||||
prompt="A robot may not injure a human being",
|
||||
echo=False,
|
||||
n=2,
|
||||
stream=stream,
|
||||
logprobs=3)
|
||||
|
||||
print("Completion results:")
|
||||
if stream:
|
||||
for c in completion:
|
||||
print(c)
|
||||
else:
|
||||
print("Completion result:", completion)
|
||||
print(completion)
|
||||
format.sh (new executable file, 108 lines)
@ -0,0 +1,108 @@
|
||||
#!/usr/bin/env bash
|
||||
# YAPF formatter, adapted from ray and skypilot.
|
||||
#
|
||||
# Usage:
|
||||
# # Do work and commit your work.
|
||||
|
||||
# # Format files that differ from origin/main.
|
||||
# bash format.sh
|
||||
|
||||
# # Commit changed files with message 'Run yapf and pylint'
|
||||
#
|
||||
#
|
||||
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
|
||||
# You are encouraged to run this locally before pushing changes for review.
|
||||
|
||||
# Cause the script to exit if a single command fails
|
||||
set -eo pipefail
|
||||
|
||||
# this stops git rev-parse from failing if we run this from the .git directory
|
||||
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
|
||||
ROOT="$(git rev-parse --show-toplevel)"
|
||||
builtin cd "$ROOT" || exit 1
|
||||
|
||||
YAPF_VERSION=$(yapf --version | awk '{print $2}')
|
||||
PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}')
|
||||
MYPY_VERSION=$(mypy --version | awk '{print $2}')
|
||||
|
||||
# # params: tool name, tool version, required version
|
||||
tool_version_check() {
|
||||
if [[ $2 != $3 ]]; then
|
||||
echo "Wrong $1 version installed: $3 is required, not $2."
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)"
|
||||
tool_version_check "pylint" $PYLINT_VERSION "$(grep "pylint==" requirements-dev.txt | cut -d'=' -f3)"
|
||||
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
|
||||
|
||||
YAPF_FLAGS=(
|
||||
'--recursive'
|
||||
'--parallel'
|
||||
)
|
||||
|
||||
YAPF_EXCLUDES=(
|
||||
'--exclude' 'build/**'
|
||||
'--exclude' 'vllm/model_executor/parallel_utils/**'
|
||||
)
|
||||
|
||||
# Format specified files
|
||||
format() {
|
||||
yapf --in-place "${YAPF_FLAGS[@]}" "$@"
|
||||
}
|
||||
|
||||
# Format files that differ from main branch. Ignores dirs that are not slated
|
||||
# for autoformat yet.
|
||||
format_changed() {
|
||||
# The `if` guard ensures that the list of filenames is not empty, which
|
||||
# could cause yapf to receive 0 positional arguments, making it hang
|
||||
# waiting for STDIN.
|
||||
#
|
||||
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
|
||||
# exist on both branches.
|
||||
MERGEBASE="$(git merge-base origin/main HEAD)"
|
||||
|
||||
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
|
||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \
|
||||
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
|
||||
fi
|
||||
|
||||
}
|
||||
|
||||
# Format all files
|
||||
format_all() {
|
||||
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm
|
||||
}
|
||||
|
||||
## This flag formats individual files. --files *must* be the first command line
|
||||
## arg to use this option.
|
||||
if [[ "$1" == '--files' ]]; then
|
||||
format "${@:2}"
|
||||
# If `--all` is passed, then any further arguments are ignored and the
|
||||
# entire python directory is formatted.
|
||||
elif [[ "$1" == '--all' ]]; then
|
||||
format_all
|
||||
else
|
||||
# Format only the files that changed in last commit.
|
||||
format_changed
|
||||
fi
|
||||
echo 'vLLM yapf: Done'
|
||||
|
||||
# Run mypy
|
||||
# TODO(zhuohan): Enable mypy
|
||||
# echo 'vLLM mypy:'
|
||||
# mypy
|
||||
|
||||
# Run Pylint
|
||||
echo 'vLLM Pylint:'
|
||||
pylint vllm
|
||||
|
||||
if ! git diff --quiet &>/dev/null; then
|
||||
echo 'Reformatted files. Please review and stage the changes.'
|
||||
echo 'Changes not staged for commit:'
|
||||
echo
|
||||
git --no-pager diff --name-only
|
||||
|
||||
exit 1
|
||||
fi
|
||||
@ -1,2 +1,13 @@
|
||||
mypy
|
||||
# formatting
|
||||
yapf==0.32.0
|
||||
pylint==2.8.2
|
||||
|
||||
# type checking
|
||||
mypy==0.991
|
||||
types-PyYAML
|
||||
types-requests
|
||||
types-setuptools
|
||||
|
||||
# testing
|
||||
pytest
|
||||
pytest-forked
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
ninja # For faster builds.
|
||||
psutil
|
||||
ray
|
||||
ray >= 2.5.1
|
||||
sentencepiece # Required for LLaMA tokenizer.
|
||||
numpy
|
||||
torch >= 2.0.0
|
||||
transformers >= 4.28.0 # Required for LLaMA.
|
||||
xformers >= 0.0.19
|
||||
transformers >= 4.33.1 # Required for Code Llama.
|
||||
xformers >= 0.0.21
|
||||
fastapi
|
||||
uvicorn
|
||||
pydantic # Required for OpenAI server.
|
||||
pydantic < 2 # Required for OpenAI server.
|
||||
|
||||
setup.py (32 changed lines)
@ -20,10 +20,9 @@ ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
|
||||
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
||||
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
if CUDA_HOME is None:
|
||||
raise RuntimeError(
|
||||
f"Cannot find CUDA at CUDA_HOME: {CUDA_HOME}. "
|
||||
"CUDA must be available in order to build the package.")
|
||||
f"Cannot find CUDA_HOME. CUDA must be available to build the package.")
|
||||
|
||||
|
||||
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
||||
@ -48,12 +47,6 @@ for i in range(device_count):
|
||||
raise RuntimeError(
|
||||
"GPUs with compute capability less than 7.0 are not supported.")
|
||||
compute_capabilities.add(major * 10 + minor)
|
||||
# If no GPU is available, add all supported compute capabilities.
|
||||
if not compute_capabilities:
|
||||
compute_capabilities = {70, 75, 80, 86, 90}
|
||||
# Add target compute capabilities to NVCC flags.
|
||||
for capability in compute_capabilities:
|
||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
|
||||
|
||||
# Validate the NVCC CUDA version.
|
||||
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
||||
@ -62,10 +55,31 @@ if nvcc_cuda_version < Version("11.0"):
|
||||
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
|
||||
if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
||||
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
|
||||
# However, GPUs with compute capability 8.9 can also run the code generated by
|
||||
# the previous versions of CUDA 11 and targeting compute capability 8.0.
|
||||
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
|
||||
# instead of 8.9.
|
||||
compute_capabilities.remove(89)
|
||||
compute_capabilities.add(80)
|
||||
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
|
||||
|
||||
# If no GPU is available, add all supported compute capabilities.
|
||||
if not compute_capabilities:
|
||||
compute_capabilities = {70, 75, 80}
|
||||
if nvcc_cuda_version >= Version("11.1"):
|
||||
compute_capabilities.add(86)
|
||||
if nvcc_cuda_version >= Version("11.8"):
|
||||
compute_capabilities.add(89)
|
||||
compute_capabilities.add(90)
|
||||
|
||||
# Add target compute capabilities to NVCC flags.
|
||||
for capability in compute_capabilities:
|
||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
|
||||
|
||||
# Use NVCC threads to parallelize the build.
|
||||
if nvcc_cuda_version >= Version("11.2"):
|
||||
num_threads = min(os.cpu_count(), 8)
|
||||
|
||||
tests/async_engine/api_server_async_engine.py (new file, 51 lines)
@ -0,0 +1,51 @@
|
||||
"""vllm.entrypoints.api_server with some extra logging for testing."""
|
||||
import argparse
|
||||
from typing import Any, Dict
|
||||
|
||||
import uvicorn
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
import vllm.entrypoints.api_server
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
|
||||
app = vllm.entrypoints.api_server.app
|
||||
|
||||
|
||||
class AsyncLLMEngineWithStats(AsyncLLMEngine):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._num_aborts = 0
|
||||
|
||||
async def abort(self, request_id: str) -> None:
|
||||
await super().abort(request_id)
|
||||
self._num_aborts += 1
|
||||
|
||||
def testing_stats(self) -> Dict[str, Any]:
|
||||
return {"num_aborted_requests": self._num_aborts}
|
||||
|
||||
|
||||
@app.get("/stats")
|
||||
def stats() -> Response:
|
||||
"""Get the statistics of the engine."""
|
||||
return JSONResponse(engine.testing_stats())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngineWithStats.from_engine_args(engine_args,
|
||||
start_engine_loop=False)
|
||||
vllm.entrypoints.api_server.engine = engine
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level="debug",
|
||||
timeout_keep_alive=vllm.entrypoints.api_server.TIMEOUT_KEEP_ALIVE)
|
||||
tests/async_engine/test_api_server.py (new file, 86 lines)
@ -0,0 +1,86 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from multiprocessing import Pool
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
def _query_server(prompt: str) -> dict:
|
||||
response = requests.post("http://localhost:8000/generate",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"max_tokens": 100,
|
||||
"temperature": 0,
|
||||
"ignore_eos": True
|
||||
})
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_server():
|
||||
script_path = Path(__file__).parent.joinpath(
|
||||
"api_server_async_engine.py").absolute()
|
||||
uvicorn_process = subprocess.Popen([
|
||||
sys.executable, "-u",
|
||||
str(script_path), "--model", "facebook/opt-125m"
|
||||
])
|
||||
yield
|
||||
uvicorn_process.terminate()
|
||||
|
||||
|
||||
def test_api_server(api_server):
|
||||
"""
|
||||
Run the API server and test it.
|
||||
|
||||
We run both the server and requests in separate processes.
|
||||
|
||||
We test that the server can handle incoming requests, including
|
||||
multiple requests at the same time, and that it can handle requests
|
||||
being cancelled without crashing.
|
||||
"""
|
||||
with Pool(32) as pool:
|
||||
# Wait until the server is ready
|
||||
prompts = ["Hello world"] * 1
|
||||
result = None
|
||||
while not result:
|
||||
try:
|
||||
for result in pool.map(_query_server, prompts):
|
||||
break
|
||||
except:
|
||||
time.sleep(1)
|
||||
|
||||
# Actual tests start here
|
||||
# Try with 1 prompt
|
||||
for result in pool.map(_query_server, prompts):
|
||||
assert result
|
||||
|
||||
num_aborted_requests = requests.get(
|
||||
"http://localhost:8000/stats").json()["num_aborted_requests"]
|
||||
assert num_aborted_requests == 0
|
||||
|
||||
# Try with 100 prompts
|
||||
prompts = ["Hello world"] * 100
|
||||
for result in pool.map(_query_server, prompts):
|
||||
assert result
|
||||
|
||||
# Cancel requests
|
||||
pool.map_async(_query_server, prompts)
|
||||
time.sleep(0.01)
|
||||
pool.terminate()
|
||||
pool.join()
|
||||
|
||||
# check cancellation stats
|
||||
num_aborted_requests = requests.get(
|
||||
"http://localhost:8000/stats").json()["num_aborted_requests"]
|
||||
assert num_aborted_requests > 0
|
||||
|
||||
# check that server still runs after cancellations
|
||||
with Pool(32) as pool:
|
||||
# Try with 100 prompts
|
||||
prompts = ["Hello world"] * 100
|
||||
for result in pool.map(_query_server, prompts):
|
||||
assert result
|
||||
tests/async_engine/test_request_tracker.py (new file, 54 lines)
@ -0,0 +1,54 @@
|
||||
import pytest
|
||||
|
||||
from vllm.engine.async_llm_engine import RequestTracker
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
|
||||
def test_request_tracker():
|
||||
tracker = RequestTracker()
|
||||
stream_1 = tracker.add_request("1")
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert len(new) == 1
|
||||
assert new[0]["request_id"] == "1"
|
||||
assert not finished
|
||||
assert not stream_1.finished
|
||||
|
||||
stream_2 = tracker.add_request("2")
|
||||
stream_3 = tracker.add_request("3")
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert len(new) == 2
|
||||
assert new[0]["request_id"] == "2"
|
||||
assert new[1]["request_id"] == "3"
|
||||
assert not finished
|
||||
assert not stream_2.finished
|
||||
assert not stream_3.finished
|
||||
|
||||
# request_ids must be unique
|
||||
with pytest.raises(KeyError):
|
||||
tracker.add_request("1")
|
||||
|
||||
tracker.abort_request("1")
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert len(finished) == 1
|
||||
assert "1" in finished
|
||||
assert not new
|
||||
assert stream_1.finished
|
||||
|
||||
stream_4 = tracker.add_request("4")
|
||||
tracker.abort_request("4")
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert len(finished) == 1
|
||||
assert "4" in finished
|
||||
assert not new
|
||||
assert stream_4.finished
|
||||
|
||||
stream_5 = tracker.add_request("5")
|
||||
tracker.process_request_output(
|
||||
RequestOutput("2", "output", [], [], finished=True))
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert len(finished) == 1
|
||||
assert "2" in finished
|
||||
assert len(new) == 1
|
||||
assert new[0]["request_id"] == "5"
|
||||
assert stream_2.finished
|
||||
assert not stream_5.finished
|
||||
tests/conftest.py (new file, 178 lines)
@ -0,0 +1,178 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
_TEST_PROMPTS = [
|
||||
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
|
||||
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
|
||||
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
|
||||
"Describe the basic components of a neural network and how it can be trained.",
|
||||
"Write a short story about a robot that dreams for the first time.",
|
||||
"Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
|
||||
"Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
|
||||
"Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_prompts() -> List[str]:
|
||||
return _TEST_PROMPTS
|
||||
|
||||
|
||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"half": torch.half,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float": torch.float,
|
||||
}
|
||||
|
||||
|
||||
class HfRunner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
tokenizer_name: Optional[str] = None,
|
||||
dtype: str = "half",
|
||||
) -> None:
|
||||
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
|
||||
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
).cuda()
|
||||
if tokenizer_name is None:
|
||||
tokenizer_name = model_name
|
||||
self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
**kwargs,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
outputs: List[Tuple[List[int], str]] = []
|
||||
for prompt in prompts:
|
||||
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
|
||||
output_ids = self.model.generate(
|
||||
input_ids.cuda(),
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
)
|
||||
output_str = self.tokenizer.batch_decode(
|
||||
output_ids,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
output_ids = output_ids.cpu().tolist()
|
||||
outputs.append((output_ids, output_str))
|
||||
return outputs
|
||||
|
||||
def generate_greedy(
|
||||
self,
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
outputs = self.generate(prompts,
|
||||
do_sample=False,
|
||||
max_new_tokens=max_tokens)
|
||||
for i in range(len(outputs)):
|
||||
output_ids, output_str = outputs[i]
|
||||
outputs[i] = (output_ids[0], output_str[0])
|
||||
return outputs
|
||||
|
||||
def generate_beam_search(
|
||||
self,
|
||||
prompts: List[str],
|
||||
beam_width: int,
|
||||
max_tokens: int,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
outputs = self.generate(prompts,
|
||||
do_sample=False,
|
||||
max_new_tokens=max_tokens,
|
||||
num_beams=beam_width,
|
||||
num_return_sequences=beam_width)
|
||||
for i in range(len(outputs)):
|
||||
output_ids, output_str = outputs[i]
|
||||
for j in range(len(output_ids)):
|
||||
output_ids[j] = [
|
||||
x for x in output_ids[j]
|
||||
if x != self.tokenizer.pad_token_id
|
||||
]
|
||||
outputs[i] = (output_ids, output_str)
|
||||
return outputs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hf_runner():
|
||||
return HfRunner
|
||||
|
||||
|
||||
class VllmRunner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
tokenizer_name: Optional[str] = None,
|
||||
dtype: str = "half",
|
||||
) -> None:
|
||||
self.model = LLM(
|
||||
model=model_name,
|
||||
tokenizer=tokenizer_name,
|
||||
trust_remote_code=True,
|
||||
dtype=dtype,
|
||||
swap_space=0,
|
||||
)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
sampling_params: SamplingParams,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
req_outputs = self.model.generate(prompts,
|
||||
sampling_params=sampling_params)
|
||||
outputs = []
|
||||
for req_output in req_outputs:
|
||||
prompt_str = req_output.prompt
|
||||
prompt_ids = req_output.prompt_token_ids
|
||||
req_sample_output_ids = []
|
||||
req_sample_output_strs = []
|
||||
for sample in req_output.outputs:
|
||||
output_str = sample.text
|
||||
output_ids = sample.token_ids
|
||||
req_sample_output_ids.append(prompt_ids + output_ids)
|
||||
req_sample_output_strs.append(prompt_str + output_str)
|
||||
outputs.append((req_sample_output_ids, req_sample_output_strs))
|
||||
return outputs
|
||||
|
||||
def generate_greedy(
|
||||
self,
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||
outputs = self.generate(prompts, greedy_params)
|
||||
return [(output_ids[0], output_str[0])
|
||||
for output_ids, output_str in outputs]
|
||||
|
||||
def generate_beam_search(
|
||||
self,
|
||||
prompts: List[str],
|
||||
beam_width: int,
|
||||
max_tokens: int,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
beam_search_params = SamplingParams(n=beam_width,
|
||||
use_beam_search=True,
|
||||
temperature=0.0,
|
||||
max_tokens=max_tokens)
|
||||
outputs = self.generate(prompts, beam_search_params)
|
||||
return outputs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vllm_runner():
|
||||
return VllmRunner
|
||||
tests/kernels/conftest.py (new file, 43 lines)
@ -0,0 +1,43 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
def create_kv_caches(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
scale = head_size**-0.5
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
||||
key_caches = []
|
||||
for _ in range(num_layers):
|
||||
key_cache = torch.empty(size=key_cache_shape,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
key_cache.uniform_(-scale, scale)
|
||||
key_caches.append(key_cache)
|
||||
|
||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
||||
value_caches = []
|
||||
for _ in range(num_layers):
|
||||
value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
value_cache.uniform_(-scale, scale)
|
||||
value_caches.append(value_cache)
|
||||
return key_caches, value_caches
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def kv_cache_factory():
|
||||
return create_kv_caches
|
||||
@ -1,20 +1,34 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers.activations import get_activation
|
||||
|
||||
from vllm import activation_ops
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
|
||||
D = [512, 4096, 5120, 13824] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
|
||||
x1, x2 = x.chunk(chunks=2, dim=1)
|
||||
return F.silu(x1) * x2
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("d", D)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def run_silu_and_mul(
|
||||
def test_silu_and_mul(
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||
activation_ops.silu_and_mul(out, x)
|
||||
@ -22,9 +36,40 @@ def run_silu_and_mul(
|
||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_silu_and_mul() -> None:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
for num_tokens in [7, 83, 2048]:
|
||||
for d in [512, 4096, 5120, 13824]:
|
||||
print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
|
||||
run_silu_and_mul(num_tokens, d, dtype)
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("d", D)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_gelu_new(
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||
activation_ops.gelu_new(out, x)
|
||||
ref_out = get_activation("gelu_new")(x)
|
||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("d", D)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
def test_gelu_fast(
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||
activation_ops.gelu_fast(out, x)
|
||||
ref_out = get_activation("gelu_fast")(x)
|
||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||
|
||||
@ -1,14 +1,24 @@
|
||||
import random
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||
|
||||
from vllm import attention_ops
|
||||
|
||||
MAX_SEQ_LEN = 4096
|
||||
TEST_SEED = 0
|
||||
MAX_SEQ_LEN = 8192
|
||||
NUM_BLOCKS = 128 # Arbitrary values for testing
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
NUM_GEN_SEQS = [7] # Arbitrary values for testing
|
||||
NUM_PREFILL_SEQS = [1, 3, 7] # Arbitrary values for testing
|
||||
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
|
||||
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||
BLOCK_SIZES = [8, 16, 32]
|
||||
USE_ALIBI = [False, True]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def ref_masked_attention(
|
||||
@ -18,29 +28,34 @@ def ref_masked_attention(
|
||||
scale: float,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
query = query * scale
|
||||
attn = torch.einsum('qhd,khd->hqk', query, key)
|
||||
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
|
||||
if attn_mask is not None:
|
||||
attn = attn + attn_mask
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
out = torch.einsum('hqk,khd->qhd', attn, value)
|
||||
attn_weights = attn_weights + attn_mask.float()
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
||||
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
|
||||
return out
|
||||
|
||||
|
||||
def ref_single_query_cached_kv_attention(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
num_queries_per_kv: int,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
) -> None:
|
||||
num_heads = value_cache.shape[1]
|
||||
num_query_heads = query.shape[1]
|
||||
num_kv_heads = value_cache.shape[1]
|
||||
head_size = value_cache.shape[2]
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs = query.shape[0]
|
||||
|
||||
num_input_tokens = query.shape[0]
|
||||
for i in range(num_input_tokens):
|
||||
block_tables = block_tables.cpu().tolist()
|
||||
context_lens = context_lens.cpu().tolist()
|
||||
for i in range(num_seqs):
|
||||
q = query[i].unsqueeze(0)
|
||||
block_table = block_tables[i]
|
||||
context_len = int(context_lens[i])
|
||||
@ -52,30 +67,138 @@ def ref_single_query_cached_kv_attention(
|
||||
block_offset = j % block_size
|
||||
|
||||
k = key_cache[block_number, :, :, block_offset, :]
|
||||
k = k.reshape(num_heads, head_size)
|
||||
k = k.reshape(num_kv_heads, head_size)
|
||||
keys.append(k)
|
||||
|
||||
v = value_cache[block_number, :, :, block_offset]
|
||||
values.append(v)
|
||||
keys = torch.stack(keys, dim=0)
|
||||
values = torch.stack(values, dim=0)
|
||||
if num_queries_per_kv > 1:
|
||||
# Handle MQA and GQA
|
||||
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
|
||||
values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
|
||||
|
||||
scale = 1.0 / (head_size ** 0.5)
|
||||
out = ref_masked_attention(q, keys, values, scale)
|
||||
out = out.view(num_heads, head_size)
|
||||
alibi_bias = None
|
||||
if alibi_slopes is not None:
|
||||
# Create the ALiBi bias used in the paged attention kernel.
|
||||
position_ids = torch.arange(context_len, device="cuda").int()
|
||||
alibi_bias = (position_ids - context_len + 1).float()
|
||||
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
|
||||
1, 1, -1)
|
||||
|
||||
out = ref_masked_attention(q, keys, values, scale, alibi_bias)
|
||||
out = out.view(num_query_heads, head_size)
|
||||
output[i].copy_(out, non_blocking=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_single_query_cached_kv_attention(
|
||||
kv_cache_factory,
|
||||
num_seqs: int,
|
||||
num_heads: Tuple[int, int],
|
||||
head_size: int,
|
||||
use_alibi: bool,
|
||||
block_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
num_query_heads, num_kv_heads = num_heads
|
||||
query = torch.empty(num_seqs,
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
query.uniform_(-scale, scale)
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
head_mapping = torch.repeat_interleave(
|
||||
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
|
||||
num_queries_per_kv)
|
||||
alibi_slopes = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads,
|
||||
dtype=torch.float,
|
||||
device="cuda")
|
||||
|
||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||
block_tables = []
|
||||
for _ in range(num_seqs):
|
||||
block_table = [
|
||||
random.randint(0, NUM_BLOCKS - 1)
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
|
||||
num_kv_heads, head_size, dtype,
|
||||
seed)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Call the paged attention kernel.
|
||||
output = torch.empty_like(query)
|
||||
attention_ops.single_query_cached_kv_attention(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
head_mapping,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
block_size,
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
)
|
||||
|
||||
# Run the reference implementation.
|
||||
ref_output = torch.empty_like(query)
|
||||
ref_single_query_cached_kv_attention(
|
||||
ref_output,
|
||||
query,
|
||||
num_queries_per_kv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
context_lens,
|
||||
scale,
|
||||
alibi_slopes,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): Due to the kernel-level differences in the two
|
||||
# implementations, there is a small numerical difference in the two
|
||||
# outputs. Thus, we use a relaxed tolerance for the test.
|
||||
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
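The reference path above relies on a ref_masked_attention helper defined earlier in this test file and not shown in this excerpt. The sketch below is only an assumption of what such a helper computes, softmax(scale * QK^T + bias) V over [tokens, heads, head_size] tensors; the exact name and signature are not taken from the diff:

from typing import Optional

import torch

def ref_masked_attention_sketch(
    query: torch.Tensor,   # [num_queries, num_heads, head_size]
    key: torch.Tensor,     # [num_keys, num_heads, head_size]
    value: torch.Tensor,   # [num_keys, num_heads, head_size]
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,  # broadcastable to [num_heads, num_queries, num_keys]
) -> torch.Tensor:
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", attn_weights, value)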
|
||||
|
||||
|
||||
def ref_multi_query_kv_attention(
|
||||
cu_seq_lens: List[int],
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
head_size = query.shape[-1]
|
||||
scale = 1.0 / (head_size ** 0.5)
|
||||
|
||||
num_seqs = len(cu_seq_lens) - 1
|
||||
ref_outputs = []
|
||||
for i in range(num_seqs):
|
||||
@ -84,10 +207,10 @@ def ref_multi_query_kv_attention(
|
||||
seq_len = end_idx - start_idx
|
||||
|
||||
# Create attention mask.
|
||||
attn_mask = torch.triu(
|
||||
torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
|
||||
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
|
||||
diagonal=1)
|
||||
attn_mask = attn_mask * torch.finfo(dtype).min
|
||||
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
|
||||
attn_mask = attn_mask.to(dtype=dtype, device="cuda")
|
||||
|
||||
ref_output = ref_masked_attention(
|
||||
query[start_idx:end_idx],
|
||||
@ -101,147 +224,43 @@ def ref_multi_query_kv_attention(
|
||||
return ref_output
|
||||
|
||||
|
||||
def ref_multi_query_cached_kv_attention(
|
||||
cu_query_lens: List[int],
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
num_heads = value_cache.shape[1]
|
||||
head_size = value_cache.shape[2]
|
||||
block_size = value_cache.shape[3]
|
||||
scale = 1.0 / (head_size ** 0.5)
|
||||
|
||||
num_queries = len(cu_query_lens) - 1
|
||||
ref_outputs = []
|
||||
for i in range(num_queries):
|
||||
start_idx = cu_query_lens[i]
|
||||
end_idx = cu_query_lens[i + 1]
|
||||
query_len = end_idx - start_idx
|
||||
context_len = int(context_lens[i])
|
||||
block_table = block_tables[i]
|
||||
|
||||
# Create attention mask
|
||||
attn_mask = torch.triu(
|
||||
torch.ones(query_len, context_len), diagonal=context_len - query_len + 1) * -1e5
|
||||
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
|
||||
|
||||
keys = []
|
||||
values = []
|
||||
for j in range(context_len):
|
||||
block_number = int(block_table[j // block_size])
|
||||
block_offset = j % block_size
|
||||
|
||||
k = key_cache[block_number, :, :, block_offset, :]
|
||||
k = k.reshape(num_heads, head_size)
|
||||
keys.append(k)
|
||||
|
||||
v = value_cache[block_number, :, :, block_offset]
|
||||
values.append(v)
|
||||
keys = torch.stack(keys, dim=0)
|
||||
values = torch.stack(values, dim=0)
|
||||
|
||||
ref_output = ref_masked_attention(
|
||||
query[start_idx:end_idx],
|
||||
keys,
|
||||
values,
|
||||
scale,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
ref_outputs.append(ref_output)
|
||||
ref_output = torch.cat(ref_outputs, dim=0)
|
||||
return ref_output
|
||||
|
||||
|
||||
# TODO(woosuk): Add tests for USE_ALIBI=True.
|
||||
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def run_single_query_cached_kv_attention(
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
qkv = torch.empty(
|
||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
qkv.uniform_(-1e-3, 1e-3)
|
||||
query, _, _ = qkv.unbind(dim=1)
|
||||
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_block_shape = (num_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.empty(
|
||||
size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
|
||||
key_cache.uniform_(-1e-3, 1e-3)
|
||||
value_block_shape = (num_heads, head_size, block_size)
|
||||
value_cache = torch.empty(
|
||||
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
|
||||
value_cache.uniform_(-1e-3, 1e-3)
|
||||
|
||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
|
||||
|
||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||
block_tables = []
|
||||
for _ in range(num_tokens):
|
||||
block_table = [
|
||||
random.randint(0, num_blocks - 1)
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
|
||||
|
||||
scale = float(1.0 / (head_size ** 0.5))
|
||||
output = torch.empty(
|
||||
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
attention_ops.single_query_cached_kv_attention(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
block_size,
|
||||
max_context_len,
|
||||
)
|
||||
|
||||
ref_output = torch.empty_like(query)
|
||||
ref_single_query_cached_kv_attention(
|
||||
ref_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
context_lens,
|
||||
)
|
||||
# NOTE(woosuk): Due to the difference in the data types the two
|
||||
# implementations use for attention softmax logits and accumulation,
|
||||
# there is a small difference in the final outputs.
|
||||
# We should use a relaxed tolerance for the test.
|
||||
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_multi_query_kv_attention(
|
||||
def test_multi_query_kv_attention(
|
||||
num_seqs: int,
|
||||
num_heads: int,
|
||||
num_heads: Tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
|
||||
num_tokens = sum(seq_lens)
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
qkv = torch.empty(
|
||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
qkv.uniform_(-1e-3, 1e-3)
|
||||
query, key, value = qkv.unbind(dim=1)
|
||||
num_query_heads, num_kv_heads = num_heads
|
||||
qkv = torch.empty(num_tokens,
|
||||
num_query_heads + 2 * num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
qkv.uniform_(-scale, scale)
|
||||
query, key, value = qkv.split(
|
||||
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)
|
||||
|
||||
attn_op = xops.fmha.cutlass.FwOp()
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
if num_queries_per_kv > 1:
|
||||
# Handle MQA and GQA
|
||||
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
|
||||
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
|
||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
|
||||
output = xops.memory_efficient_attention_forward(
|
||||
query.unsqueeze(0),
|
||||
@ -250,7 +269,6 @@ def run_multi_query_kv_attention(
|
||||
attn_bias=attn_bias,
|
||||
p=0.0,
|
||||
scale=scale,
|
||||
op=attn_op,
|
||||
)
|
||||
output = output.squeeze(0)
|
||||
|
||||
@ -262,40 +280,7 @@ def run_multi_query_kv_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
scale,
|
||||
dtype,
|
||||
)
|
||||
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
||||
|
||||
|
||||
def test_single_query_cached_kv_attention() -> None:
|
||||
torch.random.manual_seed(TEST_SEED)
|
||||
torch.cuda.manual_seed(TEST_SEED)
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
for block_size in [8, 16, 32]:
|
||||
for head_size in [64, 80, 96, 128]:
|
||||
print(f'Testing single_query_cached_kv_attention with '
|
||||
f'dtype={dtype}, block_size={block_size}, '
|
||||
f'head_size={head_size}')
|
||||
run_single_query_cached_kv_attention(
|
||||
num_tokens=37,
|
||||
num_heads=3,
|
||||
head_size=head_size,
|
||||
block_size=block_size,
|
||||
num_blocks=1024,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
def test_multi_query_kv_attention() -> None:
|
||||
torch.random.manual_seed(TEST_SEED)
|
||||
torch.cuda.manual_seed(TEST_SEED)
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
for head_size in [64, 80, 96, 128]:
|
||||
print(f'Testing multi_query_kv_attention with dtype={dtype}, '
|
||||
f'head_size={head_size}')
|
||||
run_multi_query_kv_attention(
|
||||
num_seqs=5,
|
||||
num_heads=3,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
@ -1,12 +1,32 @@
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import cache_ops
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
|
||||
NUM_LAYERS = [5] # Arbitrary values for testing
|
||||
NUM_HEADS = [8] # Arbitrary values for testing
|
||||
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||
BLOCK_SIZES = [8, 16, 32]
|
||||
NUM_BLOCKS = [1024] # Arbitrary values for testing
|
||||
NUM_MAPPINGS = [32, 256] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
||||
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def run_copy_blocks(
|
||||
def test_copy_blocks(
|
||||
kv_cache_factory,
|
||||
num_mappings: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
@ -14,151 +34,113 @@ def run_copy_blocks(
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
# Generate random block mappings.
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
# Generate random block mappings where each source block is mapped to two
|
||||
# destination blocks.
|
||||
assert 2 * num_mappings <= num_blocks
|
||||
src_blocks = random.sample(range(num_blocks), num_mappings)
|
||||
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
|
||||
dst_blocks = random.sample(remainig_blocks, num_mappings)
|
||||
block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}
|
||||
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
|
||||
block_mapping = {}
|
||||
for i in range(num_mappings):
|
||||
src = src_blocks[i]
|
||||
dst1 = dst_blocks[2 * i]
|
||||
dst2 = dst_blocks[2 * i + 1]
|
||||
block_mapping[src] = [dst1, dst2]
|
||||
|
||||
# Create the KV cache.
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
||||
key_caches = []
|
||||
for _ in range(num_layers):
|
||||
key_cache = torch.randn(
|
||||
size=key_cache_shape, dtype=dtype, device='cuda')
|
||||
key_caches.append(key_cache)
|
||||
cloned_key_caches = []
|
||||
for key_cache in key_caches:
|
||||
cloned_key_caches.append(key_cache.clone())
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
|
||||
num_layers, num_heads,
|
||||
head_size, dtype, seed)
|
||||
|
||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
||||
value_caches = []
|
||||
for _ in range(num_layers):
|
||||
value_cache = torch.randn(
|
||||
size=value_cache_shape, dtype=dtype, device='cuda')
|
||||
value_caches.append(value_cache)
|
||||
cloned_value_caches = []
|
||||
for value_cache in value_caches:
|
||||
cloned_value_caches.append(value_cache.clone())
|
||||
# Clone the KV caches.
|
||||
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
|
||||
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
|
||||
|
||||
# Call the copy blocks kernel.
|
||||
cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
|
||||
|
||||
# Reference implementation.
|
||||
# Run the reference implementation.
|
||||
for src, dsts in block_mapping.items():
|
||||
for dst in dsts:
|
||||
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
|
||||
for cloned_key_cache in cloned_key_caches:
|
||||
cloned_key_cache[dst] = cloned_key_cache[src]
|
||||
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
|
||||
for cloned_value_cache in cloned_value_caches:
|
||||
cloned_value_cache[dst] = cloned_value_cache[src]
|
||||
|
||||
# Compare the results.
|
||||
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
|
||||
assert torch.allclose(key_cache, cloned_key_cache)
|
||||
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
|
||||
for value_cache, cloned_value_cache in zip(value_caches,
|
||||
cloned_value_caches):
|
||||
assert torch.allclose(value_cache, cloned_value_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def run_reshape_and_cache(
|
||||
def test_reshape_and_cache(
|
||||
kv_cache_factory,
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
# Create a random slot mapping.
|
||||
num_slots = block_size * num_blocks
|
||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
||||
|
||||
qkv = torch.randn(
|
||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
qkv = torch.randn(num_tokens,
|
||||
3,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
_, key, value = qkv.unbind(dim=1)
|
||||
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
|
||||
cloned_key_cache = key_cache.clone()
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
|
||||
num_heads, head_size, dtype,
|
||||
seed)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
||||
value_cache = torch.randn(
|
||||
size=value_cache_shape, dtype=dtype, device='cuda')
|
||||
# Clone the KV caches.
|
||||
cloned_key_cache = key_cache.clone()
|
||||
cloned_value_cache = value_cache.clone()
|
||||
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
|
||||
# Call the reshape_and_cache kernel.
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
|
||||
slot_mapping)
|
||||
|
||||
# Run the reference implementation.
|
||||
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
|
||||
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
|
||||
block_indicies = block_indicies.cpu().tolist()
|
||||
block_offsets = slot_mapping % block_size
|
||||
block_offsets = block_offsets.cpu().tolist()
|
||||
for i in range(num_tokens):
|
||||
reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
|
||||
block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
|
||||
block_offset = slot_mapping[i] % block_size
|
||||
block_idx = block_indicies[i]
|
||||
block_offset = block_offsets[i]
|
||||
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
|
||||
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
|
||||
|
||||
assert torch.allclose(key_cache, cloned_key_cache)
|
||||
assert torch.allclose(value_cache, cloned_value_cache)
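The reference loop above decomposes each flat slot index into a (block index, in-block offset) pair. A tiny standalone sketch of that arithmetic, assuming block_size = 16:

block_size = 16
for slot in (0, 5, 16, 37):
    block_idx, block_offset = divmod(slot, block_size)
    print(slot, "->", block_idx, block_offset)
# 0 -> 0 0, 5 -> 0 5, 16 -> 1 0, 37 -> 2 5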
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_gather_cached_kv(
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
num_slots = block_size * num_blocks
|
||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
||||
|
||||
qkv = torch.randn(
|
||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
_, key, value = qkv.unbind(dim=1)
|
||||
|
||||
qkv_clone = qkv.clone()
|
||||
_, cloned_key, cloned_value = qkv_clone.unbind(dim=1)
|
||||
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
|
||||
|
||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
||||
value_cache = torch.randn(
|
||||
size=value_cache_shape, dtype=dtype, device='cuda')
|
||||
|
||||
cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)
|
||||
|
||||
# Reference implementation.
|
||||
for i in range(num_tokens):
|
||||
reshaped_key = cloned_key.reshape(num_tokens, num_heads, head_size // x, x)
|
||||
block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
|
||||
block_offset = slot_mapping[i] % block_size
|
||||
reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
|
||||
cloned_value[i] = value_cache[block_idx, :, :, block_offset]
|
||||
|
||||
assert torch.allclose(key, cloned_key)
|
||||
assert torch.allclose(value, cloned_value)
|
||||
|
||||
|
||||
def test_copy_blocks() -> None:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
run_copy_blocks(
|
||||
num_mappings=23, num_layers=7, num_heads=17, head_size=16,
|
||||
block_size=8, num_blocks=1024, dtype=dtype)
|
||||
|
||||
|
||||
def test_reshape_and_cache() -> None:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
run_reshape_and_cache(
|
||||
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
|
||||
dtype=dtype)
|
||||
|
||||
|
||||
def test_gather_cached_kv() -> None:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
run_gather_cached_kv(
|
||||
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
|
||||
dtype=dtype)
|
||||
|
||||
@ -1,33 +1,50 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm import layernorm_ops
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
HIDDEN_SIZES = [67, 768, 2048, 5120, 8192] # Arbitrary values for testing
|
||||
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
class RefRMSNorm(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__()
|
||||
weight = torch.empty(hidden_size)
|
||||
weight.uniform_(-1e-3, 1e-3)
|
||||
weight.normal_(mean=1.0, std=0.1)
|
||||
self.weight = nn.Parameter(weight)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
return self.weight * hidden_states
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance +
|
||||
self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
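For readability, the same computation as RefRMSNorm.forward above, written as a free function (an equivalent sketch, not a new API):

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x / sqrt(mean(x^2) + eps), computed in fp32, then scaled by the weight.
    x_fp32 = x.to(torch.float32)
    variance = x_fp32.pow(2).mean(-1, keepdim=True)
    return weight * (x_fp32 * torch.rsqrt(variance + eps)).to(x.dtype)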
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def run_rms_norm(
|
||||
def test_rms_norm(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
scale = float(hidden_size**-0.5)
|
||||
x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||
x.uniform_(-scale, scale)
|
||||
ref = RefRMSNorm(hidden_size).to(dtype).cuda()
|
||||
|
||||
out = torch.empty_like(x)
|
||||
@ -38,17 +55,4 @@ def run_rms_norm(
|
||||
ref.variance_epsilon,
|
||||
)
|
||||
ref_out = ref(x)
|
||||
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)
|
||||
|
||||
|
||||
def test_rms_norm() -> None:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
for num_tokens in [7, 128, 2048]:
|
||||
for hidden_size in [13, 64, 1024, 5120]:
|
||||
print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
|
||||
f'{num_tokens}, hidden_size={hidden_size}')
|
||||
run_rms_norm(
|
||||
num_tokens=num_tokens,
|
||||
hidden_size=hidden_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5)
|
||||
|
||||
@ -1,47 +1,70 @@
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import pos_encoding_ops
|
||||
|
||||
IS_NEOX_STYLE = [True, False]
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
|
||||
NUM_HEADS = [7, 12, 40, 52] # Arbitrary values for testing
|
||||
NUM_TOKENS = [11, 83, 2048] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
|
||||
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
||||
x1 = x[..., :x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
x = torch.stack((-x2, x1), dim=-1)
|
||||
return x.flatten(-2)
|
||||
|
||||
|
||||
def apply_rope(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
rotate_fn = rotate_neox if is_neox_style else rotate_gptj
|
||||
q_embed = (q * cos) + (rotate_fn(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_fn(k) * sin)
|
||||
return q_embed, k_embed
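A quick illustration of the two rotation conventions used by apply_rope above (a standalone usage sketch; torch is already imported in this file):

x = torch.arange(4.0)  # tensor([0., 1., 2., 3.])
rotate_neox(x)         # tensor([-2., -3.,  0.,  1.])  -- rotate the two halves
rotate_gptj(x)         # tensor([-1.,  0., -3.,  2.])  -- rotate interleaved even/odd pairs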
|
||||
|
||||
|
||||
class RefRotaryEmbeddingNeox(nn.Module):
|
||||
"""Reference implementation of the GPT-NeoX style rotary embedding."""
|
||||
class RefRotaryEmbedding(nn.Module):
|
||||
"""Reference implementation of rotary embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
max_position_embeddings: int = 2048,
|
||||
is_neox_style: bool,
|
||||
max_position_embeddings: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.rotary_dim = dim
|
||||
self.is_neox_style = is_neox_style
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
# Create cos and sin embeddings.
|
||||
inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
|
||||
t = torch.arange(max_position_embeddings).float()
|
||||
freqs = torch.einsum("i,j->ij", t, inv_freq.float())
|
||||
if is_neox_style:
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
else:
|
||||
emb = torch.repeat_interleave(freqs, 2, -1)
|
||||
cos = emb.cos().to(dtype=inv_freq.dtype)
|
||||
sin = emb.sin().to(dtype=inv_freq.dtype)
|
||||
self.register_buffer("cos_cached", cos, persistent=False)
|
||||
@ -53,18 +76,18 @@ class RefRotaryEmbeddingNeox(nn.Module):
|
||||
query: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
query_rot = query[..., :self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim:]
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
|
||||
|
||||
query_rot = query_rot.transpose(0, 1)
|
||||
key_rot = key_rot.transpose(0, 1)
|
||||
cos = F.embedding(positions, self.cos_cached)
|
||||
sin = F.embedding(positions, self.sin_cached)
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
|
||||
|
||||
query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
|
||||
self.is_neox_style)
|
||||
query_rot = query_rot.transpose(0, 1).contiguous()
|
||||
key_rot = key_rot.transpose(0, 1).contiguous()
|
||||
|
||||
@ -75,24 +98,44 @@ class RefRotaryEmbeddingNeox(nn.Module):
|
||||
return query, key
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def run_rotary_embedding_neox(
|
||||
def test_rotary_embedding(
|
||||
is_neox_style: bool,
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
max_position: int,
|
||||
rotary_dim: int,
|
||||
rotary_dim: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
positions = torch.randint(0, max_position, (num_tokens,), device='cuda')
|
||||
query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
|
||||
key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device='cuda')
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
|
||||
query = torch.randn(num_tokens,
|
||||
num_heads * head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
key = torch.randn(num_tokens,
|
||||
num_heads * head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
|
||||
# Create the rotary embedding.
|
||||
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
|
||||
t = torch.arange(max_position).float()
|
||||
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
||||
@ -101,20 +144,22 @@ def run_rotary_embedding_neox(
|
||||
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
|
||||
out_query = query.clone()
|
||||
out_key = key.clone()
|
||||
pos_encoding_ops.rotary_embedding_neox(
|
||||
pos_encoding_ops.rotary_embedding(
|
||||
positions,
|
||||
out_query,
|
||||
out_key,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox_style,
|
||||
)
|
||||
|
||||
# Run the reference implementation.
|
||||
ref_rotary_embedding = RefRotaryEmbeddingNeox(
|
||||
ref_rotary_embedding = RefRotaryEmbedding(
|
||||
dim=rotary_dim,
|
||||
is_neox_style=is_neox_style,
|
||||
max_position_embeddings=max_position,
|
||||
base=base,
|
||||
).to(dtype=dtype, device='cuda')
|
||||
).to(dtype=dtype, device="cuda")
|
||||
ref_query, ref_key = ref_rotary_embedding(
|
||||
positions,
|
||||
query.view(num_tokens, num_heads, head_size),
|
||||
@ -124,19 +169,5 @@ def run_rotary_embedding_neox(
|
||||
ref_key = ref_key.view(num_tokens, num_heads * head_size)
|
||||
|
||||
# Compare the results.
|
||||
assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
|
||||
assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
|
||||
|
||||
|
||||
def test_rotary_embedding_neox() -> None:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
|
||||
print(f'Running tests for head_size={head_size} and dtype={dtype}')
|
||||
run_rotary_embedding_neox(
|
||||
num_tokens=2145,
|
||||
num_heads=5,
|
||||
head_size=head_size,
|
||||
max_position=8192,
|
||||
rotary_dim=head_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
|
||||
assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
|
||||
|
||||
tests/models/test_models.py (new file)
@@ -0,0 +1,45 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.
|
||||
|
||||
Run `pytest tests/models/test_models.py --forked`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
MODELS = [
|
||||
"facebook/opt-125m",
|
||||
"gpt2",
|
||||
"bigcode/tiny_starcoder_py",
|
||||
"EleutherAI/gpt-j-6b",
|
||||
"EleutherAI/pythia-70m",
|
||||
"bigscience/bloom-560m",
|
||||
"mosaicml/mpt-7b",
|
||||
"tiiuae/falcon-7b",
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
tests/samplers/test_beam_search.py (new file)
@@ -0,0 +1,46 @@
"""Compare the outputs of HF and vLLM when using beam search.
|
||||
|
||||
Run `pytest tests/samplers/test_beam_search.py --forked`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
# FIXME(zhuohan): The test can not pass if we:
|
||||
# 1. Increase max_tokens to 256.
|
||||
# 2. Increase beam_width to 8.
|
||||
# 3. Use the model "huggyllama/llama-7b".
|
||||
MAX_TOKENS = [128]
|
||||
BEAM_WIDTHS = [4]
|
||||
MODELS = ["facebook/opt-125m"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
|
||||
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
|
||||
def test_beam_search_single_input(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
beam_width: int,
|
||||
) -> None:
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
|
||||
max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype)
|
||||
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
|
||||
max_tokens)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, _ = hf_outputs[i]
|
||||
vllm_output_ids, _ = vllm_outputs[i]
|
||||
assert len(hf_output_ids) == len(vllm_output_ids)
|
||||
for j in range(len(hf_output_ids)):
|
||||
assert hf_output_ids[j] == vllm_output_ids[j], (
|
||||
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
|
||||
f"vLLM: {vllm_output_ids}")
|
||||
@ -1,3 +1,5 @@
|
||||
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
@ -6,7 +8,7 @@ from vllm.entrypoints.llm import LLM
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
__version__ = "0.1.1"
|
||||
__version__ = "0.1.5"
|
||||
|
||||
__all__ = [
|
||||
"LLM",
|
||||
|
||||
@ -35,7 +35,8 @@ class LogicalTokenBlock:
|
||||
|
||||
def append_tokens(self, token_ids: List[int]) -> None:
|
||||
assert len(token_ids) <= self.get_num_empty_slots()
|
||||
self.token_ids[self.num_tokens:self.num_tokens + len(token_ids)] = token_ids
|
||||
curr_idx = self.num_tokens
|
||||
self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
|
||||
self.num_tokens += len(token_ids)
|
||||
|
||||
def get_token_ids(self) -> List[int]:
|
||||
|
||||
vllm/config.py
@@ -1,14 +1,15 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.config import get_config
|
||||
from vllm.utils import get_cpu_memory
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_GiB = 1 << 30
|
||||
_GB = 1 << 30
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
@ -16,11 +17,23 @@ class ModelConfig:
|
||||
|
||||
Args:
|
||||
model: Name or path of the huggingface model to use.
|
||||
tokenizer: Name or path of the huggingface tokenizer to use.
|
||||
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
|
||||
available, and "slow" will always use the slow tokenizer.
|
||||
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
|
||||
downloading the model and tokenizer.
|
||||
download_dir: Directory to download and load the weights, defaults to
the default cache directory of huggingface.
|
||||
use_np_weights: Save a numpy copy of model weights for faster loading.
|
||||
This can increase the disk usage by up to 2x.
|
||||
use_dummy_weights: Use dummy values for model weights (for profiling).
|
||||
load_format: The format of the model weights to load:
|
||||
"auto" will try to load the weights in the safetensors format and
|
||||
fall back to the pytorch bin format if safetensors format is
|
||||
not available.
|
||||
"pt" will load the weights in the pytorch bin format.
|
||||
"safetensors" will load the weights in the safetensors format.
|
||||
"npcache" will load the weights in pytorch format and store
|
||||
a numpy cache to speed up the loading.
|
||||
"dummy" will initialize the weights with random values, which is
|
||||
mainly for profiling.
|
||||
dtype: Data type for model weights and activations. The "auto" option
|
||||
will use FP16 precision for FP32 and FP16 models, and BF16 precision
|
||||
for BF16 models.
|
||||
@ -30,20 +43,44 @@ class ModelConfig:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
tokenizer: str,
|
||||
tokenizer_mode: str,
|
||||
trust_remote_code: bool,
|
||||
download_dir: Optional[str],
|
||||
use_np_weights: bool,
|
||||
use_dummy_weights: bool,
|
||||
load_format: str,
|
||||
dtype: str,
|
||||
seed: int,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
self.trust_remote_code = trust_remote_code
|
||||
self.download_dir = download_dir
|
||||
self.use_np_weights = use_np_weights
|
||||
self.use_dummy_weights = use_dummy_weights
|
||||
self.load_format = load_format
|
||||
self.seed = seed
|
||||
|
||||
self.hf_config: PretrainedConfig = AutoConfig.from_pretrained(model)
|
||||
self.hf_config = get_config(model, trust_remote_code)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
||||
self._verify_load_format()
|
||||
self._verify_tokenizer_mode()
|
||||
|
||||
def _verify_load_format(self) -> None:
|
||||
load_format = self.load_format.lower()
|
||||
if load_format not in [
|
||||
"auto", "pt", "safetensors", "npcache", "dummy"
|
||||
]:
|
||||
raise ValueError(
|
||||
f"Unknown load format: {self.load_format}. Must be one of "
|
||||
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
|
||||
self.load_format = load_format
|
||||
|
||||
def _verify_tokenizer_mode(self) -> None:
|
||||
tokenizer_mode = self.tokenizer_mode.lower()
|
||||
if tokenizer_mode not in ["auto", "slow"]:
|
||||
raise ValueError(
|
||||
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
|
||||
"either 'auto' or 'slow'.")
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
@ -73,9 +110,48 @@ class ModelConfig:
|
||||
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
|
||||
|
||||
def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
|
||||
# For GPTBigCode & Falcon:
|
||||
# Note: for falcon, when new_decoder_architecture is True, the
|
||||
# multi_query flag is ignored and we use n_head_kv for the number of
|
||||
# KV heads.
|
||||
new_decoder_arch_falcon = (
|
||||
self.hf_config.model_type == "falcon"
|
||||
and getattr(self.hf_config, "new_decoder_architecture", False))
|
||||
if not new_decoder_arch_falcon and getattr(self.hf_config,
|
||||
"multi_query", False):
|
||||
# Multi-query attention, only one KV head.
|
||||
return 1
|
||||
# For Falcon:
|
||||
if getattr(self.hf_config, "n_head_kv", None) is not None:
|
||||
return (self.hf_config.n_head_kv //
|
||||
parallel_config.tensor_parallel_size)
|
||||
# For LLaMA-2:
|
||||
if getattr(self.hf_config, "num_key_value_heads", None) is not None:
|
||||
return (self.hf_config.num_key_value_heads //
|
||||
parallel_config.tensor_parallel_size)
|
||||
total_num_attention_heads = self.hf_config.num_attention_heads
|
||||
return total_num_attention_heads // parallel_config.tensor_parallel_size
|
||||
|
||||
def get_max_model_len(self) -> int:
|
||||
max_model_len = float("inf")
|
||||
possible_keys = [
|
||||
# OPT
|
||||
"max_position_embeddings",
|
||||
# GPT-2
|
||||
"n_positions",
|
||||
# MPT
|
||||
"max_seq_len",
|
||||
# Others
|
||||
"max_sequence_length",
|
||||
"max_seq_length",
|
||||
"seq_len",
|
||||
]
|
||||
for key in possible_keys:
|
||||
max_len_key = getattr(self.hf_config, key, None)
|
||||
if max_len_key is not None:
|
||||
max_model_len = min(max_model_len, max_len_key)
|
||||
return max_model_len
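When a config exposes more than one of the keys above, the loop keeps the smallest value. A hypothetical illustration (SimpleNamespace stands in for a HF config and is not part of vLLM):

from types import SimpleNamespace

hf_config = SimpleNamespace(max_position_embeddings=4096, n_positions=2048)
# min(inf, 4096) -> 4096, then min(4096, 2048) -> 2048,
# so get_max_model_len() would report 2048 for this config.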
|
||||
|
||||
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
|
||||
total_num_hidden_layers = self.hf_config.num_hidden_layers
|
||||
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
|
||||
@ -90,6 +166,7 @@ class CacheConfig:
|
||||
vLLM execution.
|
||||
swap_space: Size of the CPU swap space per GPU (in GiB).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_size: int,
|
||||
@ -98,7 +175,7 @@ class CacheConfig:
|
||||
) -> None:
|
||||
self.block_size = block_size
|
||||
self.gpu_memory_utilization = gpu_memory_utilization
|
||||
self.swap_space_bytes = swap_space * _GiB
|
||||
self.swap_space_bytes = swap_space * _GB
|
||||
self._verify_args()
|
||||
|
||||
# Will be set after profiling.
|
||||
@ -121,14 +198,13 @@ class CacheConfig:
|
||||
num_gpus_per_node = parallel_config.tensor_parallel_size
|
||||
cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
|
||||
|
||||
msg = (
|
||||
f"{cpu_memory_usage / _GiB:.2f} GiB out of "
|
||||
f"the {total_cpu_memory / _GiB:.2f} GiB total CPU memory is "
|
||||
msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
|
||||
f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
|
||||
"allocated for the swap space.")
|
||||
if cpu_memory_usage > 0.7 * total_cpu_memory:
|
||||
raise ValueError("Too large swap space. " + msg)
|
||||
elif cpu_memory_usage > 0.4 * total_cpu_memory:
|
||||
logger.warn("Possibly too large swap space. " + msg)
|
||||
logger.warning("Possibly too large swap space. " + msg)
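A back-of-the-envelope check of the two thresholds above, with hypothetical numbers (4 GiB of swap space per GPU, tensor_parallel_size = 4, 32 GiB of total CPU memory):

_GB = 1 << 30
cpu_memory_usage = 4 * _GB * 4               # 16 GiB of CPU swap space in total
total_cpu_memory = 32 * _GB
cpu_memory_usage > 0.7 * total_cpu_memory    # False -> no ValueError
cpu_memory_usage > 0.4 * total_cpu_memory    # True  -> "Possibly too large swap space" warning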
|
||||
|
||||
|
||||
class ParallelConfig:
|
||||
@ -141,6 +217,7 @@ class ParallelConfig:
|
||||
True if either pipeline_parallel_size or tensor_parallel_size is
|
||||
greater than 1.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pipeline_parallel_size: int,
|
||||
@ -170,14 +247,15 @@ class SchedulerConfig:
|
||||
a single iteration.
|
||||
max_num_seqs: Maximum number of sequences to be processed in a single
|
||||
iteration.
|
||||
max_model_len: Maximum length of a sequence (including prompt
|
||||
and generated text).
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
max_num_batched_tokens: int,
|
||||
max_num_seqs: int,
|
||||
) -> None:
|
||||
|
||||
def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
|
||||
max_model_len: int) -> None:
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.max_num_seqs = max_num_seqs
|
||||
self.max_model_len = max_model_len
|
||||
|
||||
|
||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
@ -221,7 +299,7 @@ def _get_and_verify_dtype(
|
||||
pass
|
||||
else:
|
||||
# Casting between float16 and bfloat16 is allowed with a warning.
|
||||
logger.warn(f"Casting {config_dtype} to {torch_dtype}.")
|
||||
logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
|
||||
|
||||
# Check if the GPU supports the dtype.
|
||||
if torch_dtype == torch.bfloat16:
|
||||
|
||||
@ -27,8 +27,9 @@ class BlockAllocator:
|
||||
# Initialize the free blocks.
|
||||
self.free_blocks: List[PhysicalTokenBlock] = []
|
||||
for i in range(num_blocks):
|
||||
block = PhysicalTokenBlock(
|
||||
device=device, block_number=i, block_size=block_size)
|
||||
block = PhysicalTokenBlock(device=device,
|
||||
block_number=i,
|
||||
block_size=block_size)
|
||||
self.free_blocks.append(block)
|
||||
|
||||
def allocate(self) -> PhysicalTokenBlock:
|
||||
@ -84,10 +85,12 @@ class BlockSpaceManager:
|
||||
num_required_blocks = len(seq.logical_token_blocks)
|
||||
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
|
||||
# Use watermark to avoid frequent cache eviction.
|
||||
return num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks
|
||||
return (num_free_gpu_blocks - num_required_blocks >=
|
||||
self.watermark_blocks)
|
||||
|
||||
def allocate(self, seq_group: SequenceGroup) -> None:
|
||||
# NOTE: Here we assume that all sequences in the group have the same prompt.
|
||||
# NOTE: Here we assume that all sequences in the group have the same
|
||||
# prompt.
|
||||
seq = seq_group.get_seqs()[0]
|
||||
|
||||
# Allocate new physical token blocks that will store the prompt tokens.
|
||||
@ -143,7 +146,8 @@ class BlockSpaceManager:
|
||||
for block in src_block_table:
|
||||
block.ref_count += 1
|
||||
|
||||
def _get_physical_blocks(self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
|
||||
def _get_physical_blocks(
|
||||
self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
|
||||
# NOTE: Here, we assume that the physical blocks are only shared by
|
||||
# the sequences in the same group.
|
||||
blocks: Set[PhysicalTokenBlock] = set()
|
||||
@ -168,9 +172,7 @@ class BlockSpaceManager:
|
||||
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
|
||||
# CPU block -> GPU block.
|
||||
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
|
||||
for seq in seq_group.get_seqs():
|
||||
if seq.is_finished():
|
||||
continue
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
|
||||
new_block_table: BlockTable = []
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
|
||||
@ -199,9 +201,7 @@ class BlockSpaceManager:
|
||||
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
|
||||
# GPU block -> CPU block.
|
||||
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
|
||||
for seq in seq_group.get_seqs():
|
||||
if seq.is_finished():
|
||||
continue
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
new_block_table: BlockTable = []
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
|
||||
|
||||
@ -1,19 +1,16 @@
|
||||
import enum
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from vllm.config import CacheConfig, SchedulerConfig
|
||||
from vllm.core.block_manager import BlockSpaceManager
|
||||
from vllm.core.policy import PolicyFactory
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceOutputs,
|
||||
SequenceStatus)
|
||||
SequenceGroupMetadata, SequenceStatus)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_LOGGING_INTERVAL_SEC = 5
|
||||
|
||||
|
||||
class PreemptionMode(enum.Enum):
|
||||
"""Preemption modes.
|
||||
@ -32,20 +29,28 @@ class SchedulerOutputs:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduled_seq_groups: List[SequenceGroup],
|
||||
prompt_run: bool,
|
||||
num_batched_tokens: int,
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
ignored_seq_groups: List[SequenceGroup],
|
||||
) -> None:
|
||||
self.scheduled_seq_groups = scheduled_seq_groups
|
||||
self.prompt_run = prompt_run
|
||||
self.num_batched_tokens = num_batched_tokens
|
||||
self.blocks_to_swap_in = blocks_to_swap_in
|
||||
self.blocks_to_swap_out = blocks_to_swap_out
|
||||
self.blocks_to_copy = blocks_to_copy
|
||||
# Swap in and swap out should never happen at the same time.
|
||||
assert not (blocks_to_swap_in and blocks_to_swap_out)
|
||||
self.ignored_seq_groups = ignored_seq_groups
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return (not self.blocks_to_swap_in
|
||||
and not self.blocks_to_swap_out
|
||||
and not self.blocks_to_copy)
|
||||
# NOTE: We do not consider the ignored sequence groups.
|
||||
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
|
||||
and not self.blocks_to_swap_out and not self.blocks_to_copy)
|
||||
|
||||
|
||||
class Scheduler:
|
||||
@ -54,14 +59,15 @@ class Scheduler:
|
||||
self,
|
||||
scheduler_config: SchedulerConfig,
|
||||
cache_config: CacheConfig,
|
||||
log_stats: bool,
|
||||
) -> None:
|
||||
self.scheduler_config = scheduler_config
|
||||
self.cache_config = cache_config
|
||||
self.log_stats = log_stats
|
||||
|
||||
self.prompt_limit = min(self.scheduler_config.max_model_len,
|
||||
self.scheduler_config.max_num_batched_tokens)
|
||||
|
||||
# Instantiate the scheduling policy.
|
||||
self.policy = PolicyFactory.get_policy(policy_name='fcfs')
|
||||
self.policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
# Create the block space manager.
|
||||
self.block_manager = BlockSpaceManager(
|
||||
block_size=self.cache_config.block_size,
|
||||
@ -69,6 +75,7 @@ class Scheduler:
|
||||
num_cpu_blocks=self.cache_config.num_cpu_blocks,
|
||||
)
|
||||
|
||||
# TODO(zhuohan): Use deque instead of list for better performance.
|
||||
# Sequence groups in the WAITING state.
|
||||
self.waiting: List[SequenceGroup] = []
|
||||
# Sequence groups in the RUNNING state.
|
||||
@ -76,24 +83,29 @@ class Scheduler:
|
||||
# Sequence groups in the SWAPPED state.
|
||||
self.swapped: List[SequenceGroup] = []
|
||||
|
||||
self.last_logging_time: float = 0.0
|
||||
# List[timestamp, num_tokens]
|
||||
self.num_input_tokens: List[Tuple[float, int]] = []
|
||||
|
||||
def add_seq_group(self, seq_group: SequenceGroup) -> None:
|
||||
# Add sequence groups to the waiting queue.
|
||||
self.waiting.append(seq_group)
|
||||
|
||||
def abort_seq_group(self, request_id: str) -> None:
|
||||
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||
if isinstance(request_id, str):
|
||||
request_id = (request_id, )
|
||||
request_ids = set(request_id)
|
||||
for state_queue in [self.waiting, self.running, self.swapped]:
|
||||
for seq_group in state_queue:
|
||||
if seq_group.request_id == request_id:
|
||||
# We need to reverse the list as we are removing elements
|
||||
# from it as we iterate over it. If we don't do it,
|
||||
# indices will get messed up and we will skip over elements.
|
||||
for seq_group in reversed(state_queue):
|
||||
if seq_group.request_id in request_ids:
|
||||
# Remove the sequence group from the state queue.
|
||||
state_queue.remove(seq_group)
|
||||
for seq in seq_group.seqs:
|
||||
for seq in seq_group.get_seqs():
|
||||
if seq.is_finished():
|
||||
continue
|
||||
self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
|
||||
seq.status = SequenceStatus.FINISHED_ABORTED
|
||||
self.free_seq(seq)
|
||||
request_ids.remove(seq_group.request_id)
|
||||
if not request_ids:
|
||||
return
|
||||
|
||||
def has_unfinished_seqs(self) -> bool:
|
||||
@ -102,7 +114,7 @@ class Scheduler:
|
||||
def get_num_unfinished_seq_groups(self) -> int:
|
||||
return len(self.waiting) + len(self.running) + len(self.swapped)
|
||||
|
||||
def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
|
||||
def _schedule(self) -> SchedulerOutputs:
|
||||
# Blocks that need to be swapped or copied before model execution.
|
||||
blocks_to_swap_in: Dict[int, int] = {}
|
||||
blocks_to_swap_out: Dict[int, int] = {}
|
||||
@ -111,10 +123,72 @@ class Scheduler:
|
||||
# Fix the current time.
|
||||
now = time.time()
|
||||
|
||||
# NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
|
||||
# in order to minimize the preemption overheads.
|
||||
# Preemption happens only when there is no available slot to keep all
|
||||
# the sequence groups in the RUNNING state.
|
||||
# Join waiting sequences if possible.
|
||||
if not self.swapped:
|
||||
ignored_seq_groups: List[SequenceGroup] = []
|
||||
scheduled: List[SequenceGroup] = []
|
||||
# The total number of sequences on the fly, including the
|
||||
# requests in the generation phase.
|
||||
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
|
||||
for seq_group in self.running)
|
||||
num_batched_tokens = 0
|
||||
# Optimization: We do not sort the waiting queue since the preempted
|
||||
# sequence groups are added to the front and the new sequence groups
|
||||
# are added to the back.
|
||||
while self.waiting:
|
||||
seq_group = self.waiting[0]
|
||||
|
||||
assert seq_group.num_seqs() == 1, (
|
||||
"Waiting sequence group should have only one prompt "
|
||||
"sequence.")
|
||||
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
|
||||
if num_prompt_tokens > self.prompt_limit:
|
||||
logger.warning(
|
||||
f"Input prompt ({num_prompt_tokens} tokens) is too long"
|
||||
f" and exceeds limit of {self.prompt_limit}")
|
||||
for seq in seq_group.get_seqs():
|
||||
seq.status = SequenceStatus.FINISHED_IGNORED
|
||||
ignored_seq_groups.append(seq_group)
|
||||
self.waiting.pop(0)
|
||||
continue
|
||||
|
||||
# If the sequence group cannot be allocated, stop.
|
||||
if not self.block_manager.can_allocate(seq_group):
|
||||
break
|
||||
|
||||
# If the number of batched tokens exceeds the limit, stop.
|
||||
if (num_batched_tokens + num_prompt_tokens >
|
||||
self.scheduler_config.max_num_batched_tokens):
|
||||
break
|
||||
|
||||
# The total number of sequences in the RUNNING state should not
|
||||
# exceed the maximum number of sequences.
|
||||
num_new_seqs = seq_group.get_max_num_running_seqs()
|
||||
if (num_curr_seqs + num_new_seqs >
|
||||
self.scheduler_config.max_num_seqs):
|
||||
break
|
||||
|
||||
seq_group = self.waiting.pop(0)
|
||||
self._allocate(seq_group)
|
||||
self.running.append(seq_group)
|
||||
num_batched_tokens += num_prompt_tokens
|
||||
num_curr_seqs += num_new_seqs
|
||||
scheduled.append(seq_group)
|
||||
|
||||
if scheduled:
|
||||
scheduler_outputs = SchedulerOutputs(
|
||||
scheduled_seq_groups=scheduled,
|
||||
prompt_run=True,
|
||||
num_batched_tokens=num_batched_tokens,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
ignored_seq_groups=ignored_seq_groups,
|
||||
)
|
||||
return scheduler_outputs
|
||||
|
||||
# NOTE(woosuk): Preemption happens only when there is no available slot
|
||||
# to keep all the sequence groups in the RUNNING state.
|
||||
# In this case, the policy is responsible for deciding which sequence
|
||||
# groups to preempt.
|
||||
self.running = self.policy.sort_by_priority(now, self.running)
|
||||
@ -144,129 +218,56 @@ class Scheduler:
|
||||
|
||||
# Swap in the sequence groups in the SWAPPED state if possible.
|
||||
self.swapped = self.policy.sort_by_priority(now, self.swapped)
|
||||
while self.swapped and not blocks_to_swap_out:
|
||||
if not preempted:
|
||||
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
|
||||
for seq_group in self.running)
|
||||
|
||||
while self.swapped:
|
||||
seq_group = self.swapped[0]
|
||||
# If the sequence group has been preempted in this step, stop.
|
||||
if seq_group in preempted:
|
||||
break
|
||||
# If the sequence group cannot be swapped in, stop.
|
||||
if not self.block_manager.can_swap_in(seq_group):
|
||||
break
|
||||
|
||||
# The total number of sequences in the RUNNING state should not
|
||||
# exceed the maximum number of sequences.
|
||||
num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
|
||||
num_curr_seqs = len(self.running)
|
||||
if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
|
||||
num_new_seqs = seq_group.get_max_num_running_seqs()
|
||||
if (num_curr_seqs + num_new_seqs >
|
||||
self.scheduler_config.max_num_seqs):
|
||||
break
|
||||
|
||||
seq_group = self.swapped.pop(0)
|
||||
self._swap_in(seq_group, blocks_to_swap_in)
|
||||
self._append_slot(seq_group, blocks_to_copy)
|
||||
num_curr_seqs += num_new_seqs
|
||||
self.running.append(seq_group)
|
||||
|
||||
# Each sequence in the generation phase only takes one token slot.
|
||||
# Therefore, the number of batched tokens is equal to the number of
|
||||
# sequences in the RUNNING state.
|
||||
num_batched_tokens = sum(
|
||||
seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
||||
for seq_group in self.running
|
||||
)
|
||||
|
||||
# Join waiting sequences if possible.
|
||||
prompt_group_ids: List[str] = []
|
||||
# NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
|
||||
# prioritized over the sequence groups in the WAITING state.
|
||||
# This is because we want to bound the amount of CPU memory taken by
|
||||
# the swapped sequence groups.
|
||||
if not self.swapped:
|
||||
# Optimization: We do not sort the waiting queue since the preempted
|
||||
# sequence groups are added to the front and the new sequence groups
|
||||
# are added to the back.
|
||||
while self.waiting:
|
||||
seq_group = self.waiting[0]
|
||||
# If the sequence group has been preempted in this step, stop.
|
||||
if seq_group in preempted:
|
||||
break
|
||||
# If the sequence group cannot be allocated, stop.
|
||||
if not self.block_manager.can_allocate(seq_group):
|
||||
break
|
||||
|
||||
# If the number of batched tokens exceeds the limit, stop.
|
||||
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
|
||||
if (num_batched_tokens + num_prompt_tokens
|
||||
> self.scheduler_config.max_num_batched_tokens):
|
||||
break
|
||||
|
||||
# The total number of sequences in the RUNNING state should not
|
||||
# exceed the maximum number of sequences.
|
||||
num_new_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
|
||||
num_curr_seqs = len(self.running)
|
||||
if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
|
||||
break
|
||||
|
||||
seq_group = self.waiting.pop(0)
|
||||
self._allocate(seq_group)
|
||||
self.running.append(seq_group)
|
||||
num_batched_tokens += num_prompt_tokens
|
||||
prompt_group_ids.append(seq_group.request_id)
|
||||
for seq_group in self.running)
|
||||
|
||||
scheduler_outputs = SchedulerOutputs(
|
||||
scheduled_seq_groups=self.running,
|
||||
prompt_run=False,
|
||||
num_batched_tokens=num_batched_tokens,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
ignored_seq_groups=[],
|
||||
)
|
||||
if not self.log_stats:
|
||||
return scheduler_outputs, prompt_group_ids
|
||||
|
||||
# TODO(woosuk): Move the below code to the engine.
|
||||
now = time.time()
|
||||
if num_batched_tokens > 0:
|
||||
self.num_input_tokens.append((now, num_batched_tokens))
|
||||
elapsed_time = now - self.last_logging_time
|
||||
if elapsed_time > _LOGGING_INTERVAL_SEC:
|
||||
self.last_logging_time = now
|
||||
self.num_input_tokens = [
|
||||
(t, n) for t, n in self.num_input_tokens
|
||||
if now - t < _LOGGING_INTERVAL_SEC
|
||||
]
|
||||
if len(self.num_input_tokens) > 1:
|
||||
total_num_tokens = sum(n for _, n in self.num_input_tokens[:-1])
|
||||
window = now - self.num_input_tokens[0][0]
|
||||
avg_throughput = total_num_tokens / window
|
||||
else:
|
||||
avg_throughput = 0.0
|
||||
|
||||
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
|
||||
num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
|
||||
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
|
||||
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
|
||||
|
||||
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
|
||||
if total_num_cpu_blocks > 0:
|
||||
num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
|
||||
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
|
||||
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
|
||||
else:
|
||||
cpu_cache_usage = 0.0
|
||||
|
||||
logger.info(
|
||||
f"Throughput: {avg_throughput:.1f} tokens/s, "
|
||||
f"Running: {len(self.running)} reqs, "
|
||||
f"Swapped: {len(self.swapped)} reqs, "
|
||||
f"Pending: {len(self.waiting)} reqs, "
|
||||
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
|
||||
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
|
||||
return scheduler_outputs, prompt_group_ids
|
||||
return scheduler_outputs
|
||||
|
||||
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
|
||||
# Schedule sequence groups.
|
||||
# This function call changes the internal states of the scheduler
|
||||
# such as self.running, self.swapped, and self.waiting.
|
||||
scheduler_outputs, prompt_group_ids = self._schedule()
|
||||
scheduler_outputs = self._schedule()
|
||||
|
||||
# Create input data structures.
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
for seq_group in self.running:
|
||||
is_prompt = seq_group.request_id in prompt_group_ids
|
||||
|
||||
for seq_group in scheduler_outputs.scheduled_seq_groups:
|
||||
seq_data: Dict[int, List[SequenceData]] = {}
|
||||
block_tables: Dict[int, List[int]] = {}
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
@@ -276,7 +277,7 @@ class Scheduler:
|
||||
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=seq_group.request_id,
|
||||
is_prompt=is_prompt,
|
||||
is_prompt=scheduler_outputs.prompt_run,
|
||||
seq_data=seq_data,
|
||||
sampling_params=seq_group.sampling_params,
|
||||
block_tables=block_tables,
|
||||
@@ -284,35 +285,10 @@ class Scheduler:
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
return seq_group_metadata_list, scheduler_outputs
|
||||
|
||||
def update(
|
||||
self,
|
||||
seq_outputs: Dict[int, SequenceOutputs],
|
||||
) -> List[SequenceGroup]:
|
||||
# Update the running sequences and free blocks.
|
||||
for seq_group in self.running:
|
||||
# Process beam search results before processing the new tokens.
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
output = seq_outputs[seq.seq_id]
|
||||
if seq.seq_id != output.parent_seq_id:
|
||||
# The sequence is a fork of the parent sequence (beam search).
|
||||
# Free the current sequence.
|
||||
self.block_manager.free(seq)
|
||||
# Fork the parent sequence.
|
||||
parent_seq = seq_group.find(output.parent_seq_id)
|
||||
parent_seq.fork(seq)
|
||||
self.block_manager.fork(parent_seq, seq)
|
||||
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
||||
self.block_manager.fork(parent_seq, child_seq)
|
||||
|
||||
# Process the new tokens.
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
# Append a new token to the sequence.
|
||||
output = seq_outputs[seq.seq_id]
|
||||
seq.append_token_id(output.output_token, output.logprobs)
|
||||
# Return a shallow copy of the running queue to prevent the queue
|
||||
# from being modified by the caller.
|
||||
return self.running.copy()
|
||||
|
||||
def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
|
||||
seq.status = finish_status
|
||||
def free_seq(self, seq: Sequence) -> None:
|
||||
self.block_manager.free(seq)
|
||||
|
||||
def free_finished_seq_groups(self) -> None:
|
||||
@@ -349,8 +325,8 @@ class Scheduler:
        # If preemption mode is not specified, we determine the mode as follows:
        # We use recomputation by default since it incurs lower overhead than
        # swapping. However, when the sequence group has multiple sequences
        # (e.g., beam search), recomputation is not supported. In such a case,
        # we use swapping instead.
        # (e.g., beam search), recomputation is not currently supported. In
        # such a case, we use swapping instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
@@ -358,8 +334,7 @@ class Scheduler:
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        if preemption_mode is None:
            seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
            if len(seqs) == 1:
            if seq_group.get_max_num_running_seqs() == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP
@@ -368,7 +343,7 @@ class Scheduler:
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            assert False, 'Invalid preemption mode.'
            assert False, "Invalid preemption mode."
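The rule encoded in the hunk above is simple: a preempted group with a single running sequence is cheapest to recompute from its prompt, while a group with several sequences (beam search, or `best_of` > 1) has its KV blocks swapped to CPU instead. A minimal sketch of that decision, with illustrative types rather than vLLM's classes:

```python
# Illustrative sketch, not vLLM's implementation.
from enum import Enum, auto

class PreemptionMode(Enum):
    RECOMPUTE = auto()  # drop the KV cache and re-run the prompt later
    SWAP = auto()       # move the KV cache blocks to CPU memory

def choose_preemption_mode(max_running_seqs_in_group: int) -> PreemptionMode:
    # Recomputation only works when there is exactly one sequence; beam search
    # (or best_of > 1) forks sequences that share blocks, so those are swapped.
    if max_running_seqs_in_group == 1:
        return PreemptionMode.RECOMPUTE
    return PreemptionMode.SWAP

assert choose_preemption_mode(1) is PreemptionMode.RECOMPUTE
assert choose_preemption_mode(4) is PreemptionMode.SWAP
```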
def _preempt_by_recompute(
|
||||
self,
|
||||
@@ -388,9 +363,6 @@ class Scheduler:
|
||||
seq_group: SequenceGroup,
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
) -> None:
|
||||
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
||||
for seq in seqs:
|
||||
seq.status = SequenceStatus.SWAPPED
|
||||
self._swap_out(seq_group, blocks_to_swap_out)
|
||||
self.swapped.append(seq_group)
|
||||
|
||||
|
||||
@@ -11,10 +11,12 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
class EngineArgs:
    """Arguments for vLLM engine."""
    model: str
    tokenizer: Optional[str] = None
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    use_np_weights: bool = False
    use_dummy_weights: bool = False
    dtype: str = "auto"
    load_format: str = 'auto'
    dtype: str = 'auto'
    seed: int = 0
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
@@ -27,72 +29,117 @@ class EngineArgs:
disable_log_stats: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer is None:
|
||||
self.tokenizer = self.model
|
||||
self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
parser: argparse.ArgumentParser,
|
||||
) -> argparse.ArgumentParser:
|
||||
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
"""Shared CLI arguments for vLLM engine."""
|
||||
# Model arguments
|
||||
parser.add_argument('--model', type=str, default='facebook/opt-125m',
|
||||
parser.add_argument(
|
||||
'--model',
|
||||
type=str,
|
||||
default='facebook/opt-125m',
|
||||
help='name or path of the huggingface model to use')
|
||||
parser.add_argument('--download-dir', type=str,
|
||||
parser.add_argument(
|
||||
'--tokenizer',
|
||||
type=str,
|
||||
default=EngineArgs.tokenizer,
|
||||
help='name or path of the huggingface tokenizer to use')
|
||||
parser.add_argument('--tokenizer-mode',
|
||||
type=str,
|
||||
default=EngineArgs.tokenizer_mode,
|
||||
choices=['auto', 'slow'],
|
||||
help='tokenizer mode. "auto" will use the fast '
|
||||
'tokenizer if available, and "slow" will '
|
||||
'always use the slow tokenizer.')
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
help='trust remote code from huggingface')
|
||||
parser.add_argument('--download-dir',
|
||||
type=str,
|
||||
default=EngineArgs.download_dir,
|
||||
help='directory to download and load the weights, '
|
||||
'default to the default cache dir of '
|
||||
'huggingface')
|
||||
parser.add_argument('--use-np-weights', action='store_true',
|
||||
help='save a numpy copy of model weights for '
|
||||
'faster loading. This can increase the disk '
|
||||
'usage by up to 2x.')
|
||||
parser.add_argument('--use-dummy-weights', action='store_true',
|
||||
help='use dummy values for model weights')
|
||||
parser.add_argument(
|
||||
'--load-format',
|
||||
type=str,
|
||||
default=EngineArgs.load_format,
|
||||
choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
|
||||
help='The format of the model weights to load. '
|
||||
'"auto" will try to load the weights in the safetensors format '
|
||||
'and fall back to the pytorch bin format if safetensors format '
|
||||
'is not available. '
|
||||
'"pt" will load the weights in the pytorch bin format. '
|
||||
'"safetensors" will load the weights in the safetensors format. '
|
||||
'"npcache" will load the weights in pytorch format and store '
|
||||
'a numpy cache to speed up the loading. '
|
||||
'"dummy" will initialize the weights with random values, '
|
||||
'which is mainly for profiling.')
|
||||
# TODO(woosuk): Support FP32.
|
||||
parser.add_argument('--dtype', type=str, default=EngineArgs.dtype,
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
default=EngineArgs.dtype,
|
||||
choices=['auto', 'half', 'bfloat16', 'float'],
|
||||
help='data type for model weights and activations. '
|
||||
'The "auto" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.')
|
||||
# Parallel arguments
|
||||
parser.add_argument('--worker-use-ray', action='store_true',
|
||||
parser.add_argument('--worker-use-ray',
|
||||
action='store_true',
|
||||
help='use Ray for distributed serving, will be '
|
||||
'automatically set when using more than 1 GPU')
|
||||
parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
|
||||
parser.add_argument('--pipeline-parallel-size',
|
||||
'-pp',
|
||||
type=int,
|
||||
default=EngineArgs.pipeline_parallel_size,
|
||||
help='number of pipeline stages')
|
||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int,
|
||||
parser.add_argument('--tensor-parallel-size',
|
||||
'-tp',
|
||||
type=int,
|
||||
default=EngineArgs.tensor_parallel_size,
|
||||
help='number of tensor parallel replicas')
|
||||
# KV cache arguments
|
||||
parser.add_argument('--block-size', type=int,
|
||||
parser.add_argument('--block-size',
|
||||
type=int,
|
||||
default=EngineArgs.block_size,
|
||||
choices=[8, 16, 32],
|
||||
help='token block size')
|
||||
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
|
||||
parser.add_argument('--seed', type=int, default=EngineArgs.seed,
|
||||
parser.add_argument('--seed',
|
||||
type=int,
|
||||
default=EngineArgs.seed,
|
||||
help='random seed')
|
||||
parser.add_argument('--swap-space', type=int,
|
||||
parser.add_argument('--swap-space',
|
||||
type=int,
|
||||
default=EngineArgs.swap_space,
|
||||
help='CPU swap space size (GiB) per GPU')
|
||||
parser.add_argument('--gpu-memory-utilization', type=float,
|
||||
parser.add_argument('--gpu-memory-utilization',
|
||||
type=float,
|
||||
default=EngineArgs.gpu_memory_utilization,
|
||||
help='the percentage of GPU memory to be used for'
|
||||
'the model executor')
|
||||
parser.add_argument('--max-num-batched-tokens', type=int,
|
||||
parser.add_argument('--max-num-batched-tokens',
|
||||
type=int,
|
||||
default=EngineArgs.max_num_batched_tokens,
|
||||
help='maximum number of batched tokens per '
|
||||
'iteration')
|
||||
parser.add_argument('--max-num-seqs', type=int,
|
||||
parser.add_argument('--max-num-seqs',
|
||||
type=int,
|
||||
default=EngineArgs.max_num_seqs,
|
||||
help='maximum number of sequences per iteration')
|
||||
parser.add_argument('--disable-log-stats', action='store_true',
|
||||
parser.add_argument('--disable-log-stats',
|
||||
action='store_true',
|
||||
help='disable logging statistics')
|
||||
return parser
|
||||
|
||||
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
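`add_cli_args` and `from_cli_args` together form a round trip: dataclass field defaults seed the argparse options, and the parsed namespace is folded back into a dataclass instance. A standalone sketch of the same pattern with made-up fields (not vLLM's actual arguments):

```python
# Standalone sketch of the dataclass -> argparse -> dataclass round trip.
# Field names are illustrative, not vLLM's.
import argparse
import dataclasses
from dataclasses import dataclass

@dataclass
class ToyArgs:
    model: str = 'facebook/opt-125m'
    seed: int = 0
    block_size: int = 16

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        # Defaults come from the dataclass fields, so they live in one place.
        parser.add_argument('--model', type=str, default=ToyArgs.model)
        parser.add_argument('--seed', type=int, default=ToyArgs.seed)
        parser.add_argument('--block-size', type=int, default=ToyArgs.block_size)
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'ToyArgs':
        attrs = [f.name for f in dataclasses.fields(cls)]
        # argparse stores '--block-size' as 'block_size', matching field names.
        return cls(**{attr: getattr(args, attr) for attr in attrs})

parser = ToyArgs.add_cli_args(argparse.ArgumentParser())
print(ToyArgs.from_cli_args(parser.parse_args(['--seed', '7'])))
```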
@@ -103,16 +150,19 @@ class EngineArgs:
|
||||
self,
|
||||
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
|
||||
# Initialize the configs.
|
||||
model_config = ModelConfig(
|
||||
self.model, self.download_dir, self.use_np_weights,
|
||||
self.use_dummy_weights, self.dtype, self.seed)
|
||||
cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
|
||||
model_config = ModelConfig(self.model, self.tokenizer,
|
||||
self.tokenizer_mode, self.trust_remote_code,
|
||||
self.download_dir, self.load_format,
|
||||
self.dtype, self.seed)
|
||||
cache_config = CacheConfig(self.block_size,
|
||||
self.gpu_memory_utilization,
|
||||
self.swap_space)
|
||||
parallel_config = ParallelConfig(self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size,
|
||||
self.worker_use_ray)
|
||||
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
|
||||
self.max_num_seqs)
|
||||
self.max_num_seqs,
|
||||
model_config.get_max_model_len())
|
||||
return model_config, cache_config, parallel_config, scheduler_config
|
||||
|
||||
|
||||
@@ -124,12 +174,13 @@ class AsyncEngineArgs(EngineArgs):
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
parser: argparse.ArgumentParser,
|
||||
) -> argparse.ArgumentParser:
|
||||
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
parser.add_argument('--engine-use-ray', action='store_true',
|
||||
parser.add_argument('--engine-use-ray',
|
||||
action='store_true',
|
||||
help='use Ray to start the LLM engine in a '
|
||||
'separate process as the server process.')
|
||||
parser.add_argument('--disable-log-requests', action='store_true',
|
||||
parser.add_argument('--disable-log-requests',
|
||||
action='store_true',
|
||||
help='disable logging requests')
|
||||
return parser
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.engine.ray_utils import initialize_cluster, ray
|
||||
@@ -11,7 +13,202 @@ from vllm.sampling_params import SamplingParams
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds


class AsyncEngineDeadError(RuntimeError):
    pass


def _raise_exception_on_finish(task: asyncio.Task,
                               request_tracker: "RequestTracker") -> None:
    msg = ("Task finished unexpectedly. This should never happen! "
           "Please open an issue on Github.")
    try:
        try:
            task.result()
        except asyncio.CancelledError:
            return
        except Exception as exc:
            raise AsyncEngineDeadError(
                msg + " See stack trace above for the actual cause.") from exc
        raise AsyncEngineDeadError(msg)
    except Exception as exc:
        request_tracker.propagate_exception(exc)
        raise exc
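This helper is registered with `Task.add_done_callback` on the background engine loop (see `start_background_loop` later in this diff), so a crash is pushed to every pending request instead of dying silently. A minimal standalone sketch of that pattern; the names here are illustrative, not vLLM's:

```python
# Standalone sketch of surfacing a crashed background task via add_done_callback.
import asyncio

class DeadLoopError(RuntimeError):
    pass

def on_task_done(task: asyncio.Task, waiters: list) -> None:
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        # Surface the failure to everyone still waiting on the loop.
        for fut in waiters:
            if not fut.done():
                fut.set_exception(DeadLoopError("background loop died"))

async def main() -> None:
    loop = asyncio.get_running_loop()
    waiters = [loop.create_future()]

    async def background() -> None:
        await asyncio.sleep(0.01)
        raise ValueError("boom")  # simulated engine failure

    task = loop.create_task(background())
    task.add_done_callback(lambda t: on_task_done(t, waiters))
    try:
        await waiters[0]
    except DeadLoopError as err:
        print("request saw:", err)

asyncio.run(main())
```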
class AsyncStream:
    """A stream of RequestOutputs for a request that can be
    iterated over asynchronously."""

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._queue = asyncio.Queue()
        self._finished = False

    def put(self, item: RequestOutput) -> None:
        if self._finished:
            return
        self._queue.put_nowait(item)

    def finish(self) -> None:
        self._queue.put_nowait(StopIteration)
        self._finished = True

    @property
    def finished(self) -> bool:
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> RequestOutput:
        result = await self._queue.get()
        if result is StopIteration:
            raise StopAsyncIteration
        elif isinstance(result, Exception):
            raise result
        return result
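To see how a queue-backed stream like this behaves end to end, here is a small self-contained sketch: a simplified stream typed for strings, a producer task that pushes a few items and finishes the stream, and a consumer that drains it with `async for`. All names are illustrative.

```python
# Usage sketch for a queue-backed async stream like the one above.
# ToyStream mirrors AsyncStream but is typed for strings so it runs standalone.
import asyncio

class ToyStream:
    def __init__(self) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()

    def put(self, item: str) -> None:
        self._queue.put_nowait(item)

    def finish(self) -> None:
        self._queue.put_nowait(StopIteration)  # sentinel ends iteration

    def __aiter__(self):
        return self

    async def __anext__(self) -> str:
        result = await self._queue.get()
        if result is StopIteration:
            raise StopAsyncIteration
        return result

async def producer(stream: ToyStream) -> None:
    for token in ("Hello", " world", "!"):
        stream.put(token)
        await asyncio.sleep(0)  # yield control, as a real engine loop would
    stream.finish()

async def main() -> None:
    stream = ToyStream()
    asyncio.create_task(producer(stream))
    async for item in stream:  # the consumer sees items as they arrive
        print(item)

asyncio.run(main())
```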
class RequestTracker:
|
||||
"""Synchronous abstraction for tracking requests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._request_streams: Dict[str, AsyncStream] = {}
|
||||
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
|
||||
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
|
||||
dict]] = asyncio.Queue()
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self._request_streams
|
||||
|
||||
def propagate_exception(self, exc: Exception) -> None:
|
||||
"""Propagate an exception to all request streams."""
|
||||
for stream in self._request_streams.values():
|
||||
stream.put(exc)
|
||||
|
||||
def process_request_output(self,
|
||||
request_output: RequestOutput,
|
||||
*,
|
||||
verbose: bool = False) -> None:
|
||||
"""Process a request output from the engine."""
|
||||
request_id = request_output.request_id
|
||||
|
||||
self._request_streams[request_id].put(request_output)
|
||||
if request_output.finished:
|
||||
if verbose:
|
||||
logger.info(f"Finished request {request_id}.")
|
||||
self.abort_request(request_id)
|
||||
|
||||
def add_request(self, request_id: str,
|
||||
**engine_add_request_kwargs) -> AsyncStream:
|
||||
"""Add a request to be sent to the engine on the next background
|
||||
loop iteration."""
|
||||
if request_id in self._request_streams:
|
||||
raise KeyError(f"Request {request_id} already exists.")
|
||||
|
||||
stream = AsyncStream(request_id)
|
||||
self._new_requests.put_nowait((stream, {
|
||||
"request_id": request_id,
|
||||
**engine_add_request_kwargs
|
||||
}))
|
||||
return stream
|
||||
|
||||
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
|
||||
"""Abort a request during next background loop iteration."""
|
||||
if verbose:
|
||||
logger.info(f"Aborted request {request_id}.")
|
||||
|
||||
self._finished_requests.put_nowait(request_id)
|
||||
|
||||
if request_id not in self._request_streams or self._request_streams[
|
||||
request_id].finished:
|
||||
# The request has already finished or been aborted.
|
||||
return
|
||||
|
||||
self._request_streams[request_id].finish()
|
||||
|
||||
def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
|
||||
"""Get the new requests and finished requests to be
|
||||
sent to the engine."""
|
||||
new_requests: List[dict] = []
|
||||
finished_requests: Set[str] = set()
|
||||
|
||||
while not self._finished_requests.empty():
|
||||
request_id = self._finished_requests.get_nowait()
|
||||
finished_requests.add(request_id)
|
||||
self._request_streams.pop(request_id, None)
|
||||
|
||||
while not self._new_requests.empty():
|
||||
stream, new_request = self._new_requests.get_nowait()
|
||||
if stream.request_id in finished_requests:
|
||||
# The request has already been aborted.
|
||||
stream.finish()
|
||||
continue
|
||||
self._request_streams[stream.request_id] = stream
|
||||
new_requests.append(new_request)
|
||||
|
||||
return new_requests, finished_requests
|
||||
|
||||
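The tracker above hands work between the asyncio side and the engine: new requests and abort notifications accumulate in queues and are drained once per background-loop iteration. A self-contained sketch of that drain step with toy request ids (hypothetical names, not vLLM's classes):

```python
# Sketch of draining "new" and "finished" queues once per loop tick (toy names).
import asyncio
from typing import List, Set, Tuple

def drain(new_q: asyncio.Queue, finished_q: asyncio.Queue) -> Tuple[List[str], Set[str]]:
    new_items: List[str] = []
    finished: Set[str] = set()
    while not finished_q.empty():
        finished.add(finished_q.get_nowait())
    while not new_q.empty():
        item = new_q.get_nowait()
        if item in finished:
            continue  # aborted before it ever reached the engine
        new_items.append(item)
    return new_items, finished

new_q: asyncio.Queue = asyncio.Queue()
fin_q: asyncio.Queue = asyncio.Queue()
for rid in ("r1", "r2", "r3"):
    new_q.put_nowait(rid)
fin_q.put_nowait("r2")
print(drain(new_q, fin_q))  # (['r1', 'r3'], {'r2'})
```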
|
||||
class _AsyncLLMEngine(LLMEngine):
|
||||
"""Extension of LLMEngine to add async methods."""
|
||||
|
||||
async def step_async(self) -> List[RequestOutput]:
|
||||
"""Performs one decoding iteration and returns newly generated results.
|
||||
The workers are run asynchronously if possible.
|
||||
|
||||
This function performs one decoding iteration of the engine. It first
|
||||
schedules the sequences to be executed in the next iteration and the
|
||||
token blocks to be swapped in/out/copy. Then, it executes the model
|
||||
and updates the scheduler with the model outputs. Finally, it decodes
|
||||
the sequences and returns the newly generated results.
|
||||
"""
|
||||
(seq_group_metadata_list, scheduler_outputs,
|
||||
early_return) = self._schedule()
|
||||
if early_return is not None:
|
||||
return early_return
|
||||
|
||||
# Execute the model.
|
||||
output = await self._run_workers_async(
|
||||
"execute_model",
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
|
||||
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||
)
|
||||
|
||||
return self._process_model_outputs(output, scheduler_outputs)
|
||||
|
||||
async def _run_workers_async(
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
get_all_outputs: bool = False,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers."""
|
||||
all_outputs = []
|
||||
for worker in self.workers:
|
||||
if self.parallel_config.worker_use_ray:
|
||||
executor = partial(worker.execute_method.remote, method)
|
||||
else:
|
||||
executor = getattr(worker, method)
|
||||
|
||||
output = executor(*args, **kwargs)
|
||||
all_outputs.append(output)
|
||||
|
||||
if self.parallel_config.worker_use_ray:
|
||||
all_outputs = await asyncio.gather(*all_outputs)
|
||||
|
||||
if get_all_outputs:
|
||||
return all_outputs
|
||||
|
||||
# Make sure all workers have the same results.
|
||||
output = all_outputs[0]
|
||||
for other_output in all_outputs[1:]:
|
||||
assert output == other_output
|
||||
return output
|
||||
|
||||
|
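`_run_workers_async` fans a method call out to every worker, awaiting Ray object references with `asyncio.gather` when Ray is used and calling local workers directly otherwise. A standalone sketch of the same fan-out/gather pattern, with plain asyncio tasks standing in for remote workers (no Ray dependency; names are illustrative):

```python
# Sketch of the fan-out/gather pattern above, with asyncio tasks standing in
# for Ray object refs.
import asyncio

async def remote_worker(method: str, x: int) -> int:
    await asyncio.sleep(0.01)  # pretend this runs in another process
    return x * 2 if method == "double" else x

async def run_workers_async(method: str, x: int, num_workers: int = 4) -> int:
    # Kick off every worker first, then await them together.
    pending = [asyncio.create_task(remote_worker(method, x))
               for _ in range(num_workers)]
    all_outputs = await asyncio.gather(*pending)
    # Tensor-parallel workers are expected to agree on the result.
    assert all(o == all_outputs[0] for o in all_outputs)
    return all_outputs[0]

print(asyncio.run(run_workers_async("double", 21)))  # 42
```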
||||
class AsyncLLMEngine:
|
||||
@@ -35,53 +232,125 @@ class AsyncLLMEngine:
|
||||
log_requests: Whether to log the requests.
|
||||
*args, *kwargs: Arguments for LLMEngine.
|
||||
"""
|
||||
def __init__(self, worker_use_ray: bool, engine_use_ray: bool,
|
||||
log_requests: bool = True, *args, **kwargs) -> None:
|
||||
|
||||
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
|
||||
|
||||
def __init__(self,
|
||||
worker_use_ray: bool,
|
||||
engine_use_ray: bool,
|
||||
*args,
|
||||
log_requests: bool = True,
|
||||
start_engine_loop: bool = False,
|
||||
**kwargs) -> None:
|
||||
self.worker_use_ray = worker_use_ray
|
||||
self.engine_use_ray = engine_use_ray
|
||||
self.log_requests = log_requests
|
||||
if not self.engine_use_ray:
|
||||
engine_class = LLMEngine
|
||||
elif self.worker_use_ray:
|
||||
engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
|
||||
else:
|
||||
engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
|
||||
self.engine = engine_class(*args, **kwargs)
|
||||
# Request id -> request output.
|
||||
self.request_outputs: Dict[str, RequestOutput] = {}
|
||||
# Request id -> event to notify that there is new output.
|
||||
self.request_events: Dict[str, asyncio.Event] = {}
|
||||
self.is_engine_running = False
|
||||
self.kicking_request_id: Optional[str] = None
|
||||
self.engine = self._init_engine(*args, **kwargs)
|
||||
|
||||
async def engine_step(self, kicking_request_id: Optional[str] = None):
|
||||
self.request_tracker: RequestTracker = RequestTracker()
|
||||
self.background_loop = None
|
||||
if start_engine_loop:
|
||||
self.start_background_loop()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return (self.background_loop is not None
|
||||
and not self.background_loop.done())
|
||||
|
||||
def start_background_loop(self) -> None:
|
||||
"""Start the background loop."""
|
||||
if self.is_running:
|
||||
raise RuntimeError("Background loop is already running.")
|
||||
self.background_loop = asyncio.get_event_loop().create_task(
|
||||
self.run_engine_loop())
|
||||
self.background_loop.add_done_callback(
|
||||
partial(_raise_exception_on_finish,
|
||||
request_tracker=self.request_tracker))
|
||||
|
||||
def _init_engine(self, *args,
|
||||
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
|
||||
if not self.engine_use_ray:
|
||||
engine_class = self._engine_class
|
||||
elif self.worker_use_ray:
|
||||
engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
|
||||
else:
|
||||
engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
|
||||
return engine_class(*args, **kwargs)
|
||||
|
||||
async def engine_step(self):
|
||||
"""Kick the engine to process the waiting requests."""
|
||||
self.is_engine_running = True
|
||||
self.kicking_request_id = kicking_request_id
|
||||
|
||||
new_requests, finished_requests = (
|
||||
self.request_tracker.get_new_and_finished_requests())
|
||||
|
||||
for new_request in new_requests:
|
||||
# Add the request into the vLLM engine's waiting queue.
|
||||
# TODO: Maybe add add_request_batch to reduce Ray overhead
|
||||
if self.engine_use_ray:
|
||||
await self.engine.add_request.remote(**new_request)
|
||||
else:
|
||||
self.engine.add_request(**new_request)
|
||||
|
||||
if finished_requests:
|
||||
await self._engine_abort(finished_requests)
|
||||
|
||||
if self.engine_use_ray:
|
||||
request_outputs = await self.engine.step.remote()
|
||||
else:
|
||||
# Yield to the event loop to allow other coroutines to run
|
||||
# while is_engine_running is True. This lets the engine add new
|
||||
# requests into the queue.
|
||||
await asyncio.sleep(0)
|
||||
request_outputs = self.engine.step()
|
||||
self.is_engine_running = False
|
||||
self.kicking_request_id = None
|
||||
request_outputs = await self.engine.step_async()
|
||||
|
||||
# Notify the waiting coroutines that there are new outputs ready.
|
||||
# Put the outputs into the corresponding streams.
|
||||
for request_output in request_outputs:
|
||||
request_id = request_output.request_id
|
||||
self.request_outputs[request_id] = request_output
|
||||
self.request_events[request_id].set()
|
||||
self.request_tracker.process_request_output(
|
||||
request_output, verbose=self.log_requests)
|
||||
|
||||
async def _engine_abort(self, request_ids: Iterable[str]):
|
||||
if self.engine_use_ray:
|
||||
await self.engine.abort_request.remote(request_ids)
|
||||
else:
|
||||
self.engine.abort_request(request_ids)
|
||||
|
||||
async def run_engine_loop(self):
|
||||
while True:
|
||||
await self.engine_step()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
arrival_time: Optional[float] = None,
|
||||
) -> AsyncStream:
|
||||
if self.log_requests:
|
||||
logger.info(f"Received request {request_id}: "
|
||||
f"prompt: {prompt!r}, "
|
||||
f"sampling params: {sampling_params}, "
|
||||
f"prompt token ids: {prompt_token_ids}.")
|
||||
|
||||
if not self.is_running:
|
||||
raise AsyncEngineDeadError(
|
||||
"Background loop is not running. If it was running, "
|
||||
"inspect the output to find the stacktrace of the "
|
||||
"error that caused the background loop to stop "
|
||||
"(AsyncEngineDeadError).")
|
||||
|
||||
stream = self.request_tracker.add_request(
|
||||
request_id,
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
|
||||
return stream
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
prompt_token_ids: Optional[List[int]] = None
|
||||
) -> RequestOutput:
|
||||
prompt_token_ids: Optional[List[int]] = None) -> RequestOutput:
|
||||
"""Generate outputs for a request.
|
||||
|
||||
Generate outputs for a request. This method is a coroutine. It adds the
|
||||
@@ -103,69 +372,19 @@ class AsyncLLMEngine:
|
||||
# Preprocess the request.
|
||||
arrival_time = time.time()
|
||||
|
||||
# Create an event to notify us that there is new output from the
|
||||
# vLLM engine.
|
||||
request_event = asyncio.Event()
|
||||
self.request_events[request_id] = request_event
|
||||
|
||||
if self.log_requests:
|
||||
logger.info(f"Received request {request_id}: "
|
||||
f"prompt: {prompt!r}, "
|
||||
f"sampling params: {sampling_params}, "
|
||||
f"prompt token ids: {prompt_token_ids}.")
|
||||
|
||||
# Add the request into the vLLM engine's waiting queue.
|
||||
if self.engine_use_ray:
|
||||
await self.engine.add_request.remote(
|
||||
request_id, prompt, sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
else:
|
||||
self.engine.add_request(
|
||||
request_id, prompt, sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
|
||||
# The vLLM engine does not have a background loop that keeps
|
||||
# processing incoming requests. Therefore, we need to keep kicking
|
||||
# the engine to process the requests.
|
||||
while True:
|
||||
if request_id not in self.request_events:
|
||||
# The request has been aborted.
|
||||
return
|
||||
|
||||
# Kick the engine if the engine is not running.
|
||||
if not self.is_engine_running:
|
||||
await self.engine_step(request_id)
|
||||
|
||||
# Wait for new output. The group_event will be set in engine_step
|
||||
# when there is new output available for the sequence group.
|
||||
# Added a timeout to prevent deadlock.
|
||||
try:
|
||||
await asyncio.wait_for(request_event.wait(),
|
||||
timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
# Reset the event to wait for the next output.
|
||||
request_event.clear()
|
||||
stream = await self.add_request(request_id,
|
||||
prompt,
|
||||
sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
|
||||
# Decode and return new outputs.
|
||||
request_output = self.request_outputs[request_id]
|
||||
async for request_output in stream:
|
||||
yield request_output
|
||||
|
||||
# Once finished, release the resources of the sequence group.
|
||||
if request_output.finished:
|
||||
if self.log_requests:
|
||||
logger.info(f"Finished request {request_id}.")
|
||||
|
||||
del self.request_outputs[request_id]
|
||||
del self.request_events[request_id]
|
||||
# Kick the engine if the engine is not running. This is to
# make sure that requests still sitting in the engine's waiting
# queue get executed.
|
||||
if not self.is_engine_running:
|
||||
await self.engine_step()
|
||||
break
|
||||
except Exception as e:
|
||||
# If there is an exception, abort the request.
|
||||
self._abort(request_id)
|
||||
raise e
|
||||
|
||||
async def abort(self, request_id: str) -> None:
|
||||
"""Abort a request.
|
||||
@@ -176,43 +395,52 @@ class AsyncLLMEngine:
|
||||
Args:
|
||||
request_id: The unique id of the request.
|
||||
"""
|
||||
if request_id not in self.request_events:
|
||||
# The request has already finished or been aborted.
|
||||
return
|
||||
if not self.is_running:
|
||||
raise AsyncEngineDeadError(
|
||||
"Background loop is not running. If it was running, "
|
||||
"inspect the output to find the stacktrace of the "
|
||||
"error that caused the background loop to stop "
|
||||
"(AsyncEngineDeadError).")
|
||||
|
||||
if self.log_requests:
|
||||
logger.info(f"Aborted request {request_id}.")
|
||||
return self._abort(request_id)
|
||||
|
||||
def _abort(self, request_id: str) -> None:
|
||||
"""Abort a request.
|
||||
|
||||
Abort a submitted request. If the request is finished or not found,
|
||||
this method will be a no-op.
|
||||
|
||||
Args:
|
||||
request_id: The unique id of the request.
|
||||
"""
|
||||
self.request_tracker.abort_request(request_id,
|
||||
verbose=self.log_requests)
|
||||
|
||||
async def get_model_config(self) -> ModelConfig:
|
||||
"""Get the model configuration of the vLLM engine."""
|
||||
if self.engine_use_ray:
|
||||
await self.engine.abort_request.remote(request_id)
|
||||
return await self.engine.get_model_config.remote()
|
||||
else:
|
||||
self.engine.abort_request(request_id)
|
||||
|
||||
if request_id in self.request_events:
|
||||
del self.request_events[request_id]
|
||||
if request_id in self.request_outputs:
|
||||
del self.request_outputs[request_id]
|
||||
|
||||
# To prevent deadlock when a request is aborted while the engine is
|
||||
# running.
|
||||
if self.kicking_request_id == request_id:
|
||||
self.is_engine_running = False
|
||||
self.kicking_request_id = None
|
||||
return self.engine.get_model_config()
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(cls, engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
|
||||
def from_engine_args(cls,
|
||||
engine_args: AsyncEngineArgs,
|
||||
start_engine_loop: bool = False) -> "AsyncLLMEngine":
|
||||
"""Creates an async LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
engine_configs = engine_args.create_engine_configs()
|
||||
parallel_config = engine_configs[2]
|
||||
# Initialize the cluster.
|
||||
distributed_init_method, devices = initialize_cluster(
|
||||
distributed_init_method, placement_group = initialize_cluster(
|
||||
parallel_config, engine_args.engine_use_ray)
|
||||
# Create the async LLM engine.
|
||||
engine = cls(engine_args.worker_use_ray,
|
||||
engine_args.engine_use_ray,
|
||||
not engine_args.disable_log_requests,
|
||||
*engine_configs,
|
||||
distributed_init_method, devices,
|
||||
log_stats=not engine_args.disable_log_stats)
|
||||
distributed_init_method,
|
||||
placement_group,
|
||||
log_requests=not engine_args.disable_log_requests,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
start_engine_loop=start_engine_loop)
|
||||
return engine
|
||||
|
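Based on the constructor and classmethod shown in this hunk, standing up the async engine would look roughly like the sketch below. The module path and exact keyword names are assumptions taken from this diff and may differ in other vLLM versions.

```python
# Hedged usage sketch; paths and arguments follow the diff above and may not
# match other vLLM releases.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine  # assumed module path

engine_args = AsyncEngineArgs(model="facebook/opt-125m")
engine = AsyncLLMEngine.from_engine_args(engine_args, start_engine_loop=True)

# Requests are then streamed via the async generator `generate` shown above:
#   async for output in engine.generate("Hello", SamplingParams(), "req-0"):
#       ...
```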
||||
@@ -1,21 +1,34 @@
|
||||
import copy
|
||||
import time
|
||||
from typing import Any, List, Optional
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
from vllm.core.scheduler import Scheduler
|
||||
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray
|
||||
from vllm.engine.tokenizer_utils import detokenize_incrementally, get_tokenizer
|
||||
from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
|
||||
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceOutputs,
|
||||
SequenceStatus)
|
||||
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
||||
get_tokenizer)
|
||||
from vllm.utils import Counter
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
if ray:
|
||||
from ray.air.util.torch_dist import init_torch_dist_process_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_LOGGING_INTERVAL_SEC = 5
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
"""An LLM engine that receives requests and generates texts.
|
||||
@@ -53,19 +66,20 @@ class LLMEngine:
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
distributed_init_method: str,
|
||||
stage_devices: List[List[DeviceID]],
|
||||
placement_group: Optional["PlacementGroup"],
|
||||
log_stats: bool,
|
||||
) -> None:
|
||||
logger.info(
|
||||
"Initializing an LLM engine with config: "
|
||||
f"model={model_config.model!r}, "
|
||||
f"tokenizer={model_config.tokenizer!r}, "
|
||||
f"tokenizer_mode={model_config.tokenizer_mode}, "
|
||||
f"trust_remote_code={model_config.trust_remote_code}, "
|
||||
f"dtype={model_config.dtype}, "
|
||||
f"use_dummy_weights={model_config.use_dummy_weights}, "
|
||||
f"download_dir={model_config.download_dir!r}, "
|
||||
f"use_np_weights={model_config.use_np_weights}, "
|
||||
f"load_format={model_config.load_format}, "
|
||||
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
||||
f"seed={model_config.seed})"
|
||||
)
|
||||
f"seed={model_config.seed})")
|
||||
# TODO(woosuk): Print more configs in debug mode.
|
||||
|
||||
self.model_config = model_config
|
||||
@@ -75,34 +89,91 @@ class LLMEngine:
|
||||
self.log_stats = log_stats
|
||||
self._verify_args()
|
||||
|
||||
self.tokenizer = get_tokenizer(model_config.model)
|
||||
self.tokenizer = get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
tokenizer_mode=model_config.tokenizer_mode,
|
||||
trust_remote_code=model_config.trust_remote_code)
|
||||
self.seq_counter = Counter()
|
||||
|
||||
# Create the parallel GPU workers.
|
||||
self.workers: List[Worker] = []
|
||||
assert len(stage_devices) == 1, "Only support one stage for now."
|
||||
for rank, node_resource, _ in stage_devices[0]:
|
||||
worker_cls = Worker
|
||||
if self.parallel_config.worker_use_ray:
|
||||
worker_cls = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=1,
|
||||
resources={node_resource: 1e-3},
|
||||
)(worker_cls).remote
|
||||
self._init_workers_ray(placement_group)
|
||||
else:
|
||||
self._init_workers(distributed_init_method)
|
||||
|
||||
worker = worker_cls(
|
||||
model_config,
|
||||
parallel_config,
|
||||
scheduler_config,
|
||||
rank,
|
||||
distributed_init_method,
|
||||
)
|
||||
self.workers.append(worker)
|
||||
# Profile the memory usage and initialize the cache.
|
||||
self._init_cache()
|
||||
|
||||
# Create the scheduler.
|
||||
self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)
|
||||
self.scheduler = Scheduler(scheduler_config, cache_config)
|
||||
|
||||
# Logging.
|
||||
self.last_logging_time = 0.0
|
||||
# List of (timestamp, num_tokens)
|
||||
self.num_prompt_tokens: List[Tuple[float, int]] = []
|
||||
# List of (timestamp, num_tokens)
|
||||
self.num_generation_tokens: List[Tuple[float, int]] = []
|
||||
|
||||
def _init_workers(self, distributed_init_method: str):
|
||||
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
|
||||
|
||||
assert self.parallel_config.world_size == 1, (
|
||||
"Ray is required if parallel_config.world_size > 1.")
|
||||
|
||||
self.workers: List[Worker] = []
|
||||
worker = Worker(
|
||||
self.model_config,
|
||||
self.parallel_config,
|
||||
self.scheduler_config,
|
||||
0,
|
||||
distributed_init_method,
|
||||
)
|
||||
self.workers.append(worker)
|
||||
self._run_workers(
|
||||
"init_model",
|
||||
get_all_outputs=True,
|
||||
)
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
|
||||
|
||||
self.workers: List[Worker] = []
|
||||
for bundle in placement_group.bundle_specs:
|
||||
if not bundle.get("GPU", 0):
|
||||
continue
|
||||
worker = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=1,
|
||||
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||
placement_group=placement_group,
|
||||
placement_group_capture_child_tasks=True),
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorker).remote()
|
||||
self.workers.append(worker)
|
||||
|
||||
# Initialize torch distributed process group for the workers.
|
||||
init_torch_dist_process_group(self.workers, backend="nccl")
|
||||
model_config = copy.deepcopy(self.model_config)
|
||||
parallel_config = copy.deepcopy(self.parallel_config)
|
||||
scheduler_config = copy.deepcopy(self.scheduler_config)
|
||||
self._run_workers("init_worker",
|
||||
get_all_outputs=True,
|
||||
worker_init_fn=lambda: Worker(
|
||||
model_config,
|
||||
parallel_config,
|
||||
scheduler_config,
|
||||
None,
|
||||
None,
|
||||
))
|
||||
self._run_workers(
|
||||
"init_model",
|
||||
get_all_outputs=True,
|
||||
)
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
@@ -125,10 +196,10 @@ class LLMEngine:
|
||||
num_gpu_blocks = min(b[0] for b in num_blocks)
|
||||
num_cpu_blocks = min(b[1] for b in num_blocks)
|
||||
# FIXME(woosuk): Change to debug log.
|
||||
logger.info(f'# GPU blocks: {num_gpu_blocks}, '
|
||||
f'# CPU blocks: {num_cpu_blocks}')
|
||||
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
|
||||
f"# CPU blocks: {num_cpu_blocks}")
|
||||
|
||||
if num_gpu_blocks <= 0 or num_cpu_blocks <= 0:
|
||||
if num_gpu_blocks <= 0:
|
||||
raise ValueError("No available memory for the cache blocks. "
|
||||
"Try increasing `gpu_memory_utilization` when "
|
||||
"initializing the engine.")
|
||||
@@ -146,9 +217,12 @@ class LLMEngine:
|
||||
engine_configs = engine_args.create_engine_configs()
|
||||
parallel_config = engine_configs[2]
|
||||
# Initialize the cluster.
|
||||
distributed_init_method, devices = initialize_cluster(parallel_config)
|
||||
distributed_init_method, placement_group = initialize_cluster(
|
||||
parallel_config)
|
||||
# Create the LLM engine.
|
||||
engine = cls(*engine_configs, distributed_init_method, devices,
|
||||
engine = cls(*engine_configs,
|
||||
distributed_init_method,
|
||||
placement_group,
|
||||
log_stats=not engine_args.disable_log_stats)
|
||||
return engine
|
||||
|
||||
@@ -184,27 +258,28 @@ class LLMEngine:
|
||||
|
||||
# Create the sequences.
|
||||
block_size = self.cache_config.block_size
|
||||
seqs: List[Sequence] = []
|
||||
for _ in range(sampling_params.best_of):
|
||||
seq_id = next(self.seq_counter)
|
||||
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
|
||||
seqs.append(seq)
|
||||
|
||||
# Create the sequence group.
|
||||
seq_group = SequenceGroup(request_id, seqs, sampling_params,
|
||||
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
||||
arrival_time)
|
||||
|
||||
# Add the sequence group to the scheduler.
|
||||
self.scheduler.add_seq_group(seq_group)
|
||||
|
||||
def abort_request(self, request_id: str) -> None:
|
||||
"""Aborts a request with the given ID.
|
||||
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||
"""Aborts a request(s) with the given ID.
|
||||
|
||||
Args:
|
||||
request_id: The ID of the request to abort.
|
||||
request_id: The ID(s) of the request to abort.
|
||||
"""
|
||||
self.scheduler.abort_seq_group(request_id)
|
||||
|
||||
def get_model_config(self) -> ModelConfig:
|
||||
"""Gets the model configuration."""
|
||||
return self.model_config
|
||||
|
||||
def get_num_unfinished_requests(self) -> int:
|
||||
"""Gets the number of unfinished requests."""
|
||||
return self.scheduler.get_num_unfinished_seq_groups()
|
||||
@@ -213,6 +288,251 @@ class LLMEngine:
|
||||
"""Returns True if there are unfinished requests."""
|
||||
return self.scheduler.has_unfinished_seqs()
|
||||
|
||||
def _schedule(
|
||||
self
|
||||
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
|
||||
Optional[List[RequestOutput]]]:
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||
if scheduler_outputs.is_empty():
|
||||
return seq_group_metadata_list, scheduler_outputs, [
|
||||
RequestOutput.from_seq_group(seq_group)
|
||||
for seq_group in scheduler_outputs.ignored_seq_groups
|
||||
]
|
||||
return seq_group_metadata_list, scheduler_outputs, None
|
||||
|
||||
def _check_beam_search_early_stopping(
|
||||
self,
|
||||
early_stopping: Union[bool, str],
|
||||
sampling_params: SamplingParams,
|
||||
best_running_seq: Sequence,
|
||||
current_worst_seq: Sequence,
|
||||
) -> bool:
|
||||
assert sampling_params.use_beam_search
|
||||
length_penalty = sampling_params.length_penalty
|
||||
if early_stopping is True:
|
||||
return True
|
||||
|
||||
current_worst_score = (current_worst_seq.get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id))
|
||||
if early_stopping is False:
|
||||
highest_attainable_score = (best_running_seq.get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id))
|
||||
else:
|
||||
assert early_stopping == "never"
|
||||
if length_penalty > 0.0:
|
||||
# If length_penalty > 0.0, beam search will prefer longer
|
||||
# sequences. The highest attainable score calculation is
|
||||
# based on the longest possible sequence length in this case.
|
||||
max_possible_length = max(
|
||||
best_running_seq.get_prompt_len() +
|
||||
sampling_params.max_tokens,
|
||||
self.scheduler_config.max_model_len)
|
||||
highest_attainable_score = (
|
||||
best_running_seq.get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
seq_len=max_possible_length))
|
||||
else:
|
||||
# Otherwise, beam search will prefer shorter sequences. The
|
||||
# highest attainable score calculation is based on the current
|
||||
# sequence length.
|
||||
highest_attainable_score = (
|
||||
best_running_seq.get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id))
|
||||
return current_worst_score >= highest_attainable_score
|
||||
|
||||
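The early-stopping test above compares the worst already-finished beam against the best score a running beam could still attain. A toy illustration of a length-penalized beam score, assuming the common convention the comments imply (cumulative log-probability divided by `seq_len ** length_penalty`); the exact formula used by the engine may differ:

```python
# Toy illustration of a length-penalized beam score (assumed convention).
def beam_score(cumulative_logprob: float, seq_len: int,
               length_penalty: float = 1.0) -> float:
    return cumulative_logprob / (seq_len ** length_penalty)

short = beam_score(-4.0, seq_len=8, length_penalty=1.0)   # -0.50
long = beam_score(-5.5, seq_len=16, length_penalty=1.0)   # ~-0.34
# With length_penalty > 0 the longer hypothesis wins despite the lower raw
# log-probability, which is why the "never" early-stopping branch bounds the
# best attainable score using the longest possible sequence length.
assert long > short
```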
def _process_sequence_group_samples(
|
||||
self, seq_group: SequenceGroup,
|
||||
samples: List[SequenceOutputs]) -> None:
|
||||
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
||||
existing_finished_seqs = seq_group.get_finished_seqs()
|
||||
parent_child_dict = {
|
||||
parent_seq.seq_id: []
|
||||
for parent_seq in parent_seqs
|
||||
}
|
||||
for sample in samples:
|
||||
parent_child_dict[sample.parent_seq_id].append(sample)
|
||||
# List of (child, parent)
|
||||
child_seqs: List[Tuple[Sequence, Sequence]] = []
|
||||
|
||||
# Process the child samples for each parent sequence
|
||||
for parent in parent_seqs:
|
||||
child_samples: List[SequenceOutputs] = parent_child_dict[
|
||||
parent.seq_id]
|
||||
if len(child_samples) == 0:
|
||||
# This parent sequence has no children samples. Remove
|
||||
# the parent sequence from the sequence group since it will
|
||||
# not be used in the future iterations.
|
||||
parent.status = SequenceStatus.FINISHED_ABORTED
|
||||
seq_group.remove(parent.seq_id)
|
||||
self.scheduler.free_seq(parent)
|
||||
continue
|
||||
# Fork the parent sequence if there are multiple child samples.
|
||||
for child_sample in child_samples[:-1]:
|
||||
new_child_seq_id = next(self.seq_counter)
|
||||
child = parent.fork(new_child_seq_id)
|
||||
child.append_token_id(child_sample.output_token,
|
||||
child_sample.logprobs)
|
||||
child_seqs.append((child, parent))
|
||||
# Continue the parent sequence for the last child sample.
|
||||
# We reuse the parent sequence here to reduce redundant memory
|
||||
# copies, especially when using non-beam search sampling methods.
|
||||
last_child_sample = child_samples[-1]
|
||||
parent.append_token_id(last_child_sample.output_token,
|
||||
last_child_sample.logprobs)
|
||||
child_seqs.append((parent, parent))
|
||||
|
||||
for seq, _ in child_seqs:
|
||||
self._decode_sequence(seq)
|
||||
self._check_stop(seq, seq_group.sampling_params)
|
||||
|
||||
# Non-beam search case
|
||||
if not seq_group.sampling_params.use_beam_search:
|
||||
# For newly created child sequences, add them to the sequence group
|
||||
# and fork them in block manager if they are not finished.
|
||||
for seq, parent in child_seqs:
|
||||
if seq is not parent:
|
||||
seq_group.add(seq)
|
||||
if not seq.is_finished():
|
||||
self.scheduler.fork_seq(parent, seq)
|
||||
|
||||
# Free the finished and selected parent sequences' memory in block
|
||||
# manager. Keep them in the sequence group as candidate output.
|
||||
# NOTE: we need to fork the new sequences before freeing the
|
||||
# old sequences.
|
||||
for seq, parent in child_seqs:
|
||||
if seq is parent and seq.is_finished():
|
||||
self.scheduler.free_seq(seq)
|
||||
return
|
||||
|
||||
# Beam search case
|
||||
# Select the child sequences to keep in the sequence group.
|
||||
selected_child_seqs = []
|
||||
unselected_child_seqs = []
|
||||
beam_width = seq_group.sampling_params.best_of
|
||||
length_penalty = seq_group.sampling_params.length_penalty
|
||||
|
||||
# Select the newly finished sequences with the highest scores
|
||||
# to replace existing finished sequences.
|
||||
# Tuple of (seq, parent, is_new)
|
||||
existing_finished_seqs = [(seq, None, False)
|
||||
for seq in existing_finished_seqs]
|
||||
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
|
||||
if seq.is_finished()]
|
||||
all_finished_seqs = existing_finished_seqs + new_finished_seqs
|
||||
# Sort the finished sequences by their scores.
|
||||
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id),
|
||||
reverse=True)
|
||||
for seq, parent, is_new in all_finished_seqs[:beam_width]:
|
||||
if is_new:
|
||||
# A newly generated child sequence finishes and has a high
|
||||
# score, so we will add it into the sequence group.
|
||||
selected_child_seqs.append((seq, parent))
|
||||
for seq, parent, is_new in all_finished_seqs[beam_width:]:
|
||||
if is_new:
|
||||
# A newly generated child sequence finishes but has a low
|
||||
# score, so we will not add it into the sequence group.
|
||||
# Additionally, if this sequence is a continuation of a
|
||||
# parent sequence, we will need to remove the parent sequence
|
||||
# from the sequence group.
|
||||
unselected_child_seqs.append((seq, parent))
|
||||
else:
|
||||
# An existing finished sequence has a low score, so we will
|
||||
# remove it from the sequence group.
|
||||
seq_group.remove(seq.seq_id)
|
||||
|
||||
# select the top beam_width sequences from the running
|
||||
# sequences for the next iteration to continue the beam
|
||||
# search.
|
||||
running_child_seqs = [(seq, parent) for seq, parent in child_seqs
|
||||
if not seq.is_finished()]
|
||||
# Sort the running sequences by their scores.
|
||||
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id),
|
||||
reverse=True)
|
||||
|
||||
# Check if we can stop the beam search.
|
||||
if len(running_child_seqs) == 0:
|
||||
# No running sequences, stop the beam search.
|
||||
stop_beam_search = True
|
||||
elif len(all_finished_seqs) < beam_width:
|
||||
# Not enough finished sequences, continue the beam search.
|
||||
stop_beam_search = False
|
||||
else:
|
||||
# Check the early stopping criteria
|
||||
best_running_seq = running_child_seqs[0][0]
|
||||
current_worst_seq = all_finished_seqs[beam_width - 1][0]
|
||||
stop_beam_search = self._check_beam_search_early_stopping(
|
||||
seq_group.sampling_params.early_stopping,
|
||||
seq_group.sampling_params, best_running_seq, current_worst_seq)
|
||||
|
||||
if stop_beam_search:
|
||||
# Stop the beam search and remove all the running sequences from
|
||||
# the sequence group.
|
||||
unselected_child_seqs.extend(running_child_seqs)
|
||||
else:
|
||||
# Continue the beam search and select the top beam_width sequences
|
||||
# to continue the beam search.
|
||||
selected_child_seqs.extend(running_child_seqs[:beam_width])
|
||||
# The remaining running sequences will not be used in the next
|
||||
# iteration. Again, if these sequences are continuations of
|
||||
# parent sequences, we will need to remove the parent sequences
|
||||
# from the sequence group.
|
||||
unselected_child_seqs.extend(running_child_seqs[beam_width:])
|
||||
|
||||
# For newly created child sequences, add them to the sequence group
|
||||
# and fork them in block manager if they are not finished.
|
||||
for seq, parent in selected_child_seqs:
|
||||
if seq is not parent:
|
||||
seq_group.add(seq)
|
||||
if not seq.is_finished():
|
||||
self.scheduler.fork_seq(parent, seq)
|
||||
|
||||
# Free the finished and selected parent sequences' memory in block
|
||||
# manager. Keep them in the sequence group as candidate output.
|
||||
for seq, parent in selected_child_seqs:
|
||||
if seq is parent and seq.is_finished():
|
||||
self.scheduler.free_seq(seq)
|
||||
|
||||
# Remove the unselected parent sequences from the sequence group and
|
||||
# free their memory in block manager.
|
||||
for seq, parent in unselected_child_seqs:
|
||||
if seq is parent:
|
||||
# Remove the parent sequence if it is not selected for next
|
||||
# iteration
|
||||
seq_group.remove(seq.seq_id)
|
||||
self.scheduler.free_seq(seq)
|
||||
|
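The sampler returns one or more samples per parent sequence; the code above groups them by `parent_seq_id`, forks a new child for every sample except the last, and lets the last sample continue the parent in place to avoid an extra copy. A toy illustration of that grouping, with hypothetical ids:

```python
# Toy sketch of the parent/child bookkeeping above (hypothetical ids).
from collections import defaultdict

samples = [("seq0", "a"), ("seq0", "b"), ("seq1", "c")]  # (parent_id, token)
children = defaultdict(list)
for parent_id, token in samples:
    children[parent_id].append(token)

# A parent with several samples is forked: the last sample continues the
# parent in place, earlier ones become new child sequences.
for parent_id, tokens in children.items():
    forks, continuation = tokens[:-1], tokens[-1]
    print(parent_id, "forks:", forks, "continues with:", continuation)
```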
||||
def _process_model_outputs(
|
||||
self, output: SamplerOutput,
|
||||
scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
|
||||
# Update the scheduled sequence groups with the model outputs.
|
||||
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
|
||||
for seq_group, samples in zip(scheduled_seq_groups, output):
|
||||
self._process_sequence_group_samples(seq_group, samples)
|
||||
|
||||
# Free the finished sequence groups.
|
||||
self.scheduler.free_finished_seq_groups()
|
||||
|
||||
# Create the outputs.
|
||||
request_outputs: List[RequestOutput] = []
|
||||
for seq_group in (scheduled_seq_groups +
|
||||
scheduler_outputs.ignored_seq_groups):
|
||||
request_output = RequestOutput.from_seq_group(seq_group)
|
||||
request_outputs.append(request_output)
|
||||
|
||||
if self.log_stats:
|
||||
# Log the system stats.
|
||||
self._log_system_stats(scheduler_outputs.prompt_run,
|
||||
scheduler_outputs.num_batched_tokens)
|
||||
return request_outputs
|
||||
|
||||
def step(self) -> List[RequestOutput]:
|
||||
"""Performs one decoding iteration and returns newly generated results.
|
||||
|
||||
@@ -222,10 +542,10 @@ class LLMEngine:
|
||||
and updates the scheduler with the model outputs. Finally, it decodes
|
||||
the sequences and returns the newly generated results.
|
||||
"""
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||
if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
|
||||
# Nothing to do.
|
||||
return []
|
||||
(seq_group_metadata_list, scheduler_outputs,
|
||||
early_return) = self._schedule()
|
||||
if early_return is not None:
|
||||
return early_return
|
||||
|
||||
# Execute the model.
|
||||
output = self._run_workers(
|
||||
@@ -235,80 +555,125 @@ class LLMEngine:
|
||||
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||
)
|
||||
# Update the scheduler with the model outputs.
|
||||
seq_groups = self.scheduler.update(output)
|
||||
|
||||
# Decode the sequences.
|
||||
self._decode_sequences(seq_groups)
|
||||
# Stop the sequences that meet the stopping criteria.
|
||||
self._stop_sequences(seq_groups)
|
||||
# Free the finished sequence groups.
|
||||
self.scheduler.free_finished_seq_groups()
|
||||
return self._process_model_outputs(output, scheduler_outputs)
|
||||
|
||||
# Create the outputs.
|
||||
request_outputs: List[RequestOutput] = []
|
||||
for seq_group in seq_groups:
|
||||
request_output = RequestOutput.from_seq_group(seq_group)
|
||||
request_outputs.append(request_output)
|
||||
return request_outputs
|
||||
def _log_system_stats(
|
||||
self,
|
||||
prompt_run: bool,
|
||||
num_batched_tokens: int,
|
||||
) -> None:
|
||||
now = time.time()
|
||||
# Log the number of batched input tokens.
|
||||
if prompt_run:
|
||||
self.num_prompt_tokens.append((now, num_batched_tokens))
|
||||
else:
|
||||
self.num_generation_tokens.append((now, num_batched_tokens))
|
||||
|
||||
def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
|
||||
"""Decodes the sequence outputs."""
|
||||
for seq_group in seq_groups:
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
elapsed_time = now - self.last_logging_time
|
||||
if elapsed_time < _LOGGING_INTERVAL_SEC:
|
||||
return
|
||||
|
||||
# Discard the old stats.
|
||||
self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens
|
||||
if now - t < _LOGGING_INTERVAL_SEC]
|
||||
self.num_generation_tokens = [(t, n)
|
||||
for t, n in self.num_generation_tokens
|
||||
if now - t < _LOGGING_INTERVAL_SEC]
|
||||
|
||||
if len(self.num_prompt_tokens) > 1:
|
||||
total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
|
||||
window = now - self.num_prompt_tokens[0][0]
|
||||
avg_prompt_throughput = total_num_tokens / window
|
||||
else:
|
||||
avg_prompt_throughput = 0.0
|
||||
if len(self.num_generation_tokens) > 1:
|
||||
total_num_tokens = sum(n
|
||||
for _, n in self.num_generation_tokens[:-1])
|
||||
window = now - self.num_generation_tokens[0][0]
|
||||
avg_generation_throughput = total_num_tokens / window
|
||||
else:
|
||||
avg_generation_throughput = 0.0
|
||||
|
||||
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
|
||||
num_free_gpu_blocks = (
|
||||
self.scheduler.block_manager.get_num_free_gpu_blocks())
|
||||
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
|
||||
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
|
||||
|
||||
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
|
||||
if total_num_cpu_blocks > 0:
|
||||
num_free_cpu_blocks = (
|
||||
self.scheduler.block_manager.get_num_free_cpu_blocks())
|
||||
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
|
||||
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
|
||||
else:
|
||||
cpu_cache_usage = 0.0
|
||||
|
||||
logger.info("Avg prompt throughput: "
|
||||
f"{avg_prompt_throughput:.1f} tokens/s, "
|
||||
"Avg generation throughput: "
|
||||
f"{avg_generation_throughput:.1f} tokens/s, "
|
||||
f"Running: {len(self.scheduler.running)} reqs, "
|
||||
f"Swapped: {len(self.scheduler.swapped)} reqs, "
|
||||
f"Pending: {len(self.scheduler.waiting)} reqs, "
|
||||
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
|
||||
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
|
||||
self.last_logging_time = now
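The stats logger above keeps (timestamp, token-count) pairs and averages them over a sliding window before printing throughput. A minimal standalone sketch of that windowed average, assuming the same _LOGGING_INTERVAL_SEC convention (the helper name and the sample numbers are illustrative, not vLLM API):

import time

_LOGGING_INTERVAL_SEC = 10.0

def windowed_throughput(samples, now, window=_LOGGING_INTERVAL_SEC):
    """Average tokens/s over the samples inside the window.

    `samples` is a list of (timestamp, num_tokens) tuples, oldest first.
    Mirrors the logic above: drop stale entries, then divide the token
    count (excluding the newest sample) by the elapsed span.
    """
    samples = [(t, n) for t, n in samples if now - t < window]
    if len(samples) <= 1:
        return 0.0
    total_tokens = sum(n for _, n in samples[:-1])
    span = now - samples[0][0]
    return total_tokens / span

now = time.time()
samples = [(now - 4.0, 512), (now - 2.0, 640), (now, 600)]
print(f"{windowed_throughput(samples, now):.1f} tokens/s")  # ~288 tokens/s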
def _decode_sequence(self, seq: Sequence) -> None:
|
||||
"""Decodes the new token for a sequence."""
|
||||
new_token, new_output_text = detokenize_incrementally(
|
||||
self.tokenizer,
|
||||
seq.output_tokens,
|
||||
seq.get_last_token_id(),
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
if new_token is not None:
|
||||
seq.output_tokens.append(new_token)
|
||||
seq.output_text = new_output_text
|
||||
|
||||
def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
|
||||
def _check_stop(self, seq: Sequence,
|
||||
sampling_params: SamplingParams) -> None:
|
||||
"""Stop the finished sequences."""
|
||||
for seq_group in seq_groups:
|
||||
sampling_params = seq_group.sampling_params
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
# Check if the sequence has generated a stop string.
|
||||
stopped = False
|
||||
for stop_str in sampling_params.stop:
|
||||
if seq.output_text.endswith(stop_str):
|
||||
# Truncate the output text so that the stop string is
|
||||
# not included in the output.
|
||||
seq.output_text = seq.output_text[:-len(stop_str)]
|
||||
self.scheduler.free_seq(seq,
|
||||
SequenceStatus.FINISHED_STOPPED)
|
||||
stopped = True
|
||||
break
|
||||
if stopped:
|
||||
continue
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
|
||||
# Check if the sequence has reached max_model_len.
|
||||
if seq.get_len() > self.scheduler_config.max_model_len:
|
||||
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
return
|
||||
|
||||
# Check if the sequence has reached max_tokens.
|
||||
if seq.get_output_len() == sampling_params.max_tokens:
|
||||
self.scheduler.free_seq(
|
||||
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
|
||||
continue
|
||||
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
return
|
||||
|
||||
# Check if the sequence has generated the EOS token.
|
||||
if not sampling_params.ignore_eos:
|
||||
if seq.get_last_token_id() == self.tokenizer.eos_token_id:
|
||||
self.scheduler.free_seq(seq,
|
||||
SequenceStatus.FINISHED_STOPPED)
|
||||
continue
|
||||
if ((not sampling_params.ignore_eos)
|
||||
and seq.get_last_token_id() == self.tokenizer.eos_token_id):
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
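The stop checks above cut the stop string from the tail of the output text before marking the sequence finished. A self-contained sketch of that truncation rule (the helper is hypothetical, only illustrating the check above):

from typing import List, Optional

def apply_stop_strings(output_text: str, stop: List[str]) -> Optional[str]:
    """Return the truncated text if a stop string was hit, else None.

    Mirrors the check above: the stop string itself is not included
    in the returned output.
    """
    for stop_str in stop:
        if output_text.endswith(stop_str):
            return output_text[:-len(stop_str)]
    return None

assert apply_stop_strings("Hello World\n###", ["###"]) == "Hello World\n"
assert apply_stop_strings("no stop here", ["###"]) is None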
def _run_workers(
|
||||
self,
|
||||
method: str,
|
||||
get_all_outputs: bool = False,
|
||||
*args,
|
||||
get_all_outputs: bool = False,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers."""
|
||||
all_outputs = []
|
||||
for worker in self.workers:
|
||||
executor = getattr(worker, method)
|
||||
if self.parallel_config.worker_use_ray:
|
||||
executor = executor.remote
|
||||
executor = partial(worker.execute_method.remote, method)
|
||||
else:
|
||||
executor = getattr(worker, method)
|
||||
|
||||
output = executor(*args, **kwargs)
|
||||
all_outputs.append(output)
|
||||
|
||||
@ -1,21 +1,49 @@
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
try:
|
||||
import ray
|
||||
except ImportError:
|
||||
ray = None
|
||||
import socket
|
||||
from typing import Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
|
||||
DeviceID = Tuple[int, Optional[str], int] # rank, node resource (node IP), device id
|
||||
try:
|
||||
import ray
|
||||
from ray.air.util.torch_dist import TorchDistributedWorker
|
||||
|
||||
class RayWorker(TorchDistributedWorker):
|
||||
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
|
||||
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.worker = None
|
||||
|
||||
def init_worker(self, worker_init_fn):
|
||||
self.worker = worker_init_fn()
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.worker, name)
|
||||
|
||||
def execute_method(self, method, *args, **kwargs):
|
||||
executor = getattr(self, method)
|
||||
return executor(*args, **kwargs)
|
||||
|
||||
except ImportError:
|
||||
ray = None
|
||||
TorchDistributedWorker = None
|
||||
RayWorker = None # pylint: disable=invalid-name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
|
||||
def get_open_port():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
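Binding to port 0 asks the OS for any free port, which the engine then uses to build the tcp:// init method for torch.distributed when Ray workers are not used. A small usage sketch (the URL format mirrors the code above; the printed port is an example):

import socket

def get_open_port() -> int:
    # Port 0 lets the OS pick an unused port; getsockname() reports it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]

port = get_open_port()
distributed_init_method = f"tcp://localhost:{port}"
print(distributed_init_method)  # e.g. tcp://localhost:53187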
def initialize_cluster(
|
||||
parallel_config: ParallelConfig,
|
||||
engine_use_ray: bool = False,
|
||||
ray_address: Optional[str] = None,
|
||||
) -> Tuple[str, List[List[DeviceID]]]:
|
||||
) -> Tuple[str, Optional["PlacementGroup"]]:
|
||||
"""Initialize the distributed cluster probably with Ray.
|
||||
|
||||
Args:
|
||||
@ -37,71 +65,46 @@ def initialize_cluster(
|
||||
"Ray is not installed. Please install Ray to use distributed "
|
||||
"serving.")
|
||||
# Connect to a ray cluster.
|
||||
ray.init(address=ray_address)
|
||||
ray.init(address=ray_address, ignore_reinit_error=True)
|
||||
|
||||
if not parallel_config.worker_use_ray:
|
||||
# Initialize cluster locally.
|
||||
port = random.randint(10000, 20000)
|
||||
port = get_open_port()
|
||||
# We need to setup the distributed init method to make sure
|
||||
# the distributed megatron code (e.g., get world size) works correctly.
|
||||
distributed_init_method = f"tcp://localhost:{port}"
|
||||
all_stage_devices = [[(0, None, 0)]]
|
||||
return distributed_init_method, all_stage_devices
|
||||
return distributed_init_method, None
|
||||
|
||||
# Assume we have a uniform cluster that each node has the same number of
|
||||
# GPUs for now.
|
||||
valid_node_resources = []
|
||||
num_devices_per_node = None
|
||||
for node in ray.nodes():
|
||||
if (not node['Alive']) or node['Resources']['GPU'] <= 0:
|
||||
continue
|
||||
if num_devices_per_node is None:
|
||||
num_devices_per_node = node['Resources']['GPU']
|
||||
else:
|
||||
assert num_devices_per_node == node['Resources']['GPU'], (
|
||||
"The number of GPUs per node is not uniform.")
|
||||
for key in node['Resources']:
|
||||
if key.startswith('node:'):
|
||||
valid_node_resources.append(key)
|
||||
|
||||
# Verify the parallel config.
|
||||
num_nodes = len(valid_node_resources)
|
||||
if parallel_config.world_size > num_nodes * num_devices_per_node:
|
||||
current_placement_group = ray.util.get_current_placement_group()
|
||||
if current_placement_group:
|
||||
# We are in a placement group
|
||||
bundles = current_placement_group.bundle_specs
|
||||
# Verify that we can use the placement group.
|
||||
gpu_bundles = 0
|
||||
for bundle in bundles:
|
||||
bundle_gpus = bundle.get("GPU", 0)
|
||||
if bundle_gpus > 1:
|
||||
raise ValueError(
|
||||
"Placement group bundle cannot have more than 1 GPU.")
|
||||
if bundle_gpus:
|
||||
gpu_bundles += 1
|
||||
if parallel_config.world_size > gpu_bundles:
|
||||
raise ValueError(
|
||||
"The number of required GPUs exceeds the total number of "
|
||||
"available GPUs.")
|
||||
if parallel_config.tensor_parallel_size >= num_devices_per_node:
|
||||
if parallel_config.tensor_parallel_size % num_devices_per_node != 0:
|
||||
raise ValueError(
|
||||
"The number of tensor parallelism is not divisible by the "
|
||||
"number of GPUs per node.")
|
||||
"available GPUs in the placement group.")
|
||||
else:
|
||||
if num_devices_per_node % parallel_config.tensor_parallel_size != 0:
|
||||
num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
|
||||
if parallel_config.world_size > num_gpus_in_cluster:
|
||||
raise ValueError(
|
||||
"The number of GPUs per node is not divisible by the number "
|
||||
"of tensor parallelism.")
|
||||
"The number of required GPUs exceeds the total number of "
|
||||
"available GPUs in the cluster.")
|
||||
# Create a new placement group
|
||||
current_placement_group = ray.util.placement_group([{
|
||||
"GPU": 1
|
||||
}] * parallel_config.world_size)
|
||||
# Wait until PG is ready - this will block until all
|
||||
# requested resources are available, and will timeout
|
||||
# if they cannot be provisioned.
|
||||
ray.get(current_placement_group.ready(), timeout=1800)
|
||||
|
||||
# Assign GPUs to pipeline stages.
|
||||
rank = 0
|
||||
current_node_id = 0
|
||||
current_device_id = 0
|
||||
distributed_init_method = None
|
||||
all_stage_devices = []
|
||||
|
||||
for _ in range(parallel_config.pipeline_parallel_size):
|
||||
stage_devices = []
|
||||
for _ in range(parallel_config.tensor_parallel_size):
|
||||
node_resource = valid_node_resources[current_node_id]
|
||||
stage_devices.append((rank, node_resource, current_device_id))
|
||||
if distributed_init_method is None:
|
||||
ip = node_resource.split("node:")[-1]
|
||||
port = random.randint(10000, 20000)
|
||||
distributed_init_method = f"tcp://{ip}:{port}"
|
||||
rank += 1
|
||||
current_device_id += 1
|
||||
if current_device_id >= num_devices_per_node:
|
||||
current_node_id += 1
|
||||
current_device_id = 0
|
||||
all_stage_devices.append(stage_devices)
|
||||
|
||||
return distributed_init_method, all_stage_devices
|
||||
return None, current_placement_group
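The cluster setup above either reuses the placement group the driver already runs in or creates one with a single-GPU bundle per worker, then blocks until the bundles are reserved. A minimal sketch of creating and waiting on such a group; it assumes Ray is installed and a cluster with at least world_size GPUs is reachable, and world_size itself is illustrative:

import ray
from ray.util.placement_group import placement_group

ray.init(ignore_reinit_error=True)

world_size = 2  # illustrative tensor_parallel_size * pipeline_parallel_size
# One bundle per worker, each requesting exactly one GPU.
pg = placement_group([{"GPU": 1}] * world_size)
# Block until all bundles are reserved; raises if they cannot be
# provisioned within the timeout.
ray.get(pg.ready(), timeout=1800)
print(pg.bundle_specs)  # e.g. [{'GPU': 1.0}, {'GPU': 1.0}]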
@ -1,92 +0,0 @@
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast)
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_MODEL_TYPES_WITH_SLOW_TOKENIZER = []
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
model_name: str,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||
"""Gets a tokenizer for the given model name via Huggingface."""
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
if "open_llama" in model_name:
|
||||
kwargs["use_fast"] = False
|
||||
logger.info(
|
||||
"OpenLLaMA models do not support the fast tokenizer. "
|
||||
"Using the slow tokenizer instead.")
|
||||
elif config.model_type == "llama" and getattr(kwargs, "use_fast", True):
|
||||
# LLaMA fast tokenizer causes protobuf errors in some environments.
|
||||
# However, we found that the below LLaMA fast tokenizer works well in
|
||||
# most environments.
|
||||
model_name = "hf-internal-testing/llama-tokenizer"
|
||||
logger.info(
|
||||
f"Using the LLaMA fast tokenizer in '{model_name}' to avoid "
|
||||
"potential protobuf errors.")
|
||||
elif config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
|
||||
if getattr(kwargs, "use_fast", False) == True:
|
||||
raise ValueError(
|
||||
f"Cannot use the fast tokenizer for {config.model_type} due to "
|
||||
"bugs in the fast tokenizer.")
|
||||
logger.info(
|
||||
f"Using the slow tokenizer for {config.model_type} due to bugs in "
|
||||
"the fast tokenizer. This could potentially lead to performance "
|
||||
"degradation.")
|
||||
kwargs["use_fast"] = False
|
||||
return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
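For LLaMA checkpoints the helper above substitutes the known-good hf-internal-testing/llama-tokenizer fast tokenizer to avoid protobuf errors. A hedged usage sketch of the resulting call (requires network access to the Hugging Face Hub; the prompt is illustrative):

from transformers import AutoTokenizer

# Roughly what get_tokenizer() above resolves to for a LLaMA model:
# the fast tokenizer from a known-good repo instead of the model's own.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
print(tokenizer("Hello, vLLM!").input_ids)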
def detokenize_incrementally(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
prev_output_tokens: List[str],
|
||||
new_token_id: int,
|
||||
skip_special_tokens: bool,
|
||||
) -> Tuple[str, str]:
|
||||
"""Detokenizes the new token in conjuction with the previous output tokens.
|
||||
|
||||
NOTE: This function does not update prev_output_tokens.
|
||||
|
||||
Returns:
|
||||
new_token: The new token as a string.
|
||||
output_text: The new output text as a string.
|
||||
"""
|
||||
new_token = tokenizer.convert_ids_to_tokens(
|
||||
new_token_id, skip_special_tokens=skip_special_tokens)
|
||||
output_tokens = prev_output_tokens + [new_token]
|
||||
|
||||
# Convert the tokens to a string.
|
||||
# Optimization: If the tokenizer does not have `added_tokens_encoder`,
|
||||
# then we can directly use `convert_tokens_to_string`.
|
||||
if not getattr(tokenizer, "added_tokens_encoder", {}):
|
||||
output_text = tokenizer.convert_tokens_to_string(output_tokens)
|
||||
return new_token, output_text
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
|
||||
# NOTE(woosuk): The following code is slow because it runs a for loop over
|
||||
# the output_tokens. In Python, running a for loop over a list can be slow
|
||||
# even when the loop body is very simple.
|
||||
sub_texts = []
|
||||
current_sub_text = []
|
||||
for token in output_tokens:
|
||||
if skip_special_tokens and token in tokenizer.all_special_ids:
|
||||
continue
|
||||
if token in tokenizer.added_tokens_encoder:
|
||||
if current_sub_text:
|
||||
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
|
||||
sub_texts.append(sub_text)
|
||||
current_sub_text = []
|
||||
sub_texts.append(token)
|
||||
else:
|
||||
current_sub_text.append(token)
|
||||
if current_sub_text:
|
||||
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
|
||||
sub_texts.append(sub_text)
|
||||
output_text = " ".join(sub_texts)
|
||||
return new_token, output_text
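The helper above is driven one token at a time: the caller keeps the list of previously decoded tokens, passes each new token id, and appends the returned token itself. A self-contained sketch of that loop, using a simplified copy of the fast path above (GPT-2 is only an example tokenizer; downloading it requires network access):

from transformers import AutoTokenizer

def detokenize_one(tokenizer, prev_output_tokens, new_token_id):
    # Simplified copy of the fast path above (no added special tokens).
    new_token = tokenizer.convert_ids_to_tokens(new_token_id)
    output_text = tokenizer.convert_tokens_to_string(
        prev_output_tokens + [new_token])
    return new_token, output_text

tokenizer = AutoTokenizer.from_pretrained("gpt2")
output_tokens, output_text = [], ""
for token_id in tokenizer("Hello world").input_ids:
    new_token, output_text = detokenize_one(tokenizer, output_tokens, token_id)
    output_tokens.append(new_token)
print(output_text)  # Hello world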
@ -3,7 +3,7 @@ import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import BackgroundTasks, FastAPI, Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
import uvicorn
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
@ -12,8 +12,9 @@ from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds.
|
||||
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
|
||||
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
|
||||
app = FastAPI()
|
||||
engine = None
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
@ -30,6 +31,10 @@ async def generate(request: Request) -> Response:
|
||||
stream = request_dict.pop("stream", False)
|
||||
sampling_params = SamplingParams(**request_dict)
|
||||
request_id = random_uuid()
|
||||
|
||||
if not engine.is_running:
|
||||
engine.start_background_loop()
|
||||
|
||||
results_generator = engine.generate(prompt, sampling_params, request_id)
|
||||
|
||||
# Streaming case
|
||||
@ -37,8 +42,7 @@ async def generate(request: Request) -> Response:
|
||||
async for request_output in results_generator:
|
||||
prompt = request_output.prompt
|
||||
text_outputs = [
|
||||
prompt + output.text
|
||||
for output in request_output.outputs
|
||||
prompt + output.text for output in request_output.outputs
|
||||
]
|
||||
ret = {"text": text_outputs}
|
||||
yield (json.dumps(ret) + "\0").encode("utf-8")
|
||||
@ -63,12 +67,9 @@ async def generate(request: Request) -> Response:
|
||||
|
||||
assert final_output is not None
|
||||
prompt = final_output.prompt
|
||||
text_outputs = [
|
||||
prompt + output.text
|
||||
for output in final_output.outputs
|
||||
]
|
||||
text_outputs = [prompt + output.text for output in final_output.outputs]
|
||||
ret = {"text": text_outputs}
|
||||
return Response(content=json.dumps(ret))
|
||||
return JSONResponse(ret)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -79,7 +80,11 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args,
|
||||
start_engine_loop=False)
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
|
||||
uvicorn.run(app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level="debug",
|
||||
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
|
||||
|
||||
@ -25,6 +25,11 @@ class LLM:
|
||||
|
||||
Args:
|
||||
model: The name or path of a HuggingFace Transformers model.
|
||||
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
|
||||
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
|
||||
if available, and "slow" will always use the slow tokenizer.
|
||||
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
|
||||
downloading the model and tokenizer.
|
||||
tensor_parallel_size: The number of GPUs to use for distributed
|
||||
execution with tensor parallelism.
|
||||
dtype: The data type for the model weights and activations. Currently,
|
||||
@ -38,6 +43,9 @@ class LLM:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
tokenizer: Optional[str] = None,
|
||||
tokenizer_mode: str = "auto",
|
||||
trust_remote_code: bool = False,
|
||||
tensor_parallel_size: int = 1,
|
||||
dtype: str = "auto",
|
||||
seed: int = 0,
|
||||
@ -47,6 +55,9 @@ class LLM:
|
||||
kwargs["disable_log_stats"] = True
|
||||
engine_args = EngineArgs(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
trust_remote_code=trust_remote_code,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
dtype=dtype,
|
||||
seed=seed,
|
||||
@ -56,10 +67,15 @@ class LLM:
|
||||
self.request_counter = Counter()
|
||||
|
||||
def get_tokenizer(
|
||||
self,
|
||||
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||
self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||
return self.llm_engine.tokenizer
|
||||
|
||||
def set_tokenizer(
|
||||
self,
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
) -> None:
|
||||
self.llm_engine.tokenizer = tokenizer
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: Optional[Union[str, List[str]]] = None,
|
||||
@ -139,4 +155,8 @@ class LLM:
|
||||
pbar.update(1)
|
||||
if use_tqdm:
|
||||
pbar.close()
|
||||
# Sort the outputs by request ID.
|
||||
# This is necessary because some requests may be finished earlier than
|
||||
# its previous requests.
|
||||
outputs = sorted(outputs, key=lambda x: int(x.request_id))
|
||||
return outputs
|
||||
|
||||
@ -1,47 +1,61 @@
|
||||
# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
|
||||
# Adapted from
|
||||
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
|
||||
|
||||
import argparse
|
||||
from http import HTTPStatus
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
from http import HTTPStatus
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import fastapi
|
||||
import uvicorn
|
||||
from fastapi import BackgroundTasks, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
import uvicorn
|
||||
from packaging import version
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.tokenizer_utils import get_tokenizer
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
CompletionRequest, CompletionResponse, CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
|
||||
CompletionResponseStreamChoice, CompletionStreamResponse,
|
||||
ChatCompletionRequest, ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
||||
LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
try:
|
||||
import fastchat
|
||||
from fastchat.conversation import Conversation, SeparatorStyle
|
||||
from fastchat.model.model_adapter import get_conversation_template
|
||||
_fastchat_available = True
|
||||
except ImportError:
|
||||
_fastchat_available = False
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds
|
||||
|
||||
logger = init_logger(__name__)
|
||||
served_model = None
|
||||
app = fastapi.FastAPI()
|
||||
engine = None
|
||||
|
||||
|
||||
def create_error_response(status_code: HTTPStatus,
|
||||
message: str) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
ErrorResponse(message=message, type="invalid_request_error").dict(),
|
||||
status_code=status_code.value
|
||||
)
|
||||
return JSONResponse(ErrorResponse(message=message,
|
||||
type="invalid_request_error").dict(),
|
||||
status_code=status_code.value)
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request, exc):
|
||||
async def validation_exception_handler(request, exc): # pylint: disable=unused-argument
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
|
||||
|
||||
|
||||
@ -55,11 +69,88 @@ async def check_model(request) -> Optional[JSONResponse]:
|
||||
return ret
|
||||
|
||||
|
||||
async def get_gen_prompt(request) -> str:
|
||||
if not _fastchat_available:
|
||||
raise ModuleNotFoundError(
|
||||
"fastchat is not installed. Please install fastchat to use "
|
||||
"the chat completion and conversation APIs: `$ pip install fschat`"
|
||||
)
|
||||
if version.parse(fastchat.__version__) < version.parse("0.2.23"):
|
||||
raise ImportError(
|
||||
f"fastchat version is low. Current version: {fastchat.__version__} "
|
||||
"Please upgrade fastchat to use: `$ pip install -U fschat`")
|
||||
|
||||
conv = get_conversation_template(request.model)
|
||||
conv = Conversation(
|
||||
name=conv.name,
|
||||
system_template=conv.system_template,
|
||||
system_message=conv.system_message,
|
||||
roles=conv.roles,
|
||||
messages=list(conv.messages), # prevent in-place modification
|
||||
offset=conv.offset,
|
||||
sep_style=SeparatorStyle(conv.sep_style),
|
||||
sep=conv.sep,
|
||||
sep2=conv.sep2,
|
||||
stop_str=conv.stop_str,
|
||||
stop_token_ids=conv.stop_token_ids,
|
||||
)
|
||||
|
||||
if isinstance(request.messages, str):
|
||||
prompt = request.messages
|
||||
else:
|
||||
for message in request.messages:
|
||||
msg_role = message["role"]
|
||||
if msg_role == "system":
|
||||
conv.system_message = message["content"]
|
||||
elif msg_role == "user":
|
||||
conv.append_message(conv.roles[0], message["content"])
|
||||
elif msg_role == "assistant":
|
||||
conv.append_message(conv.roles[1], message["content"])
|
||||
else:
|
||||
raise ValueError(f"Unknown role: {msg_role}")
|
||||
|
||||
# Add a blank message for the assistant.
|
||||
conv.append_message(conv.roles[1], None)
|
||||
prompt = conv.get_prompt()
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
async def check_length(
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
prompt: Optional[str] = None,
|
||||
prompt_ids: Optional[List[int]] = None
|
||||
) -> Tuple[List[int], Optional[JSONResponse]]:
|
||||
assert (not (prompt is None and prompt_ids is None)
|
||||
and not (prompt is not None and prompt_ids is not None)
|
||||
), "Either prompt or prompt_ids should be provided."
|
||||
if prompt_ids is not None:
|
||||
input_ids = prompt_ids
|
||||
else:
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
token_num = len(input_ids)
|
||||
|
||||
if token_num + request.max_tokens > max_model_len:
|
||||
return input_ids, create_error_response(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
f"This model's maximum context length is {max_model_len} tokens. "
|
||||
f"However, you requested {request.max_tokens + token_num} tokens "
|
||||
f"({token_num} in the messages, "
|
||||
f"{request.max_tokens} in the completion). "
|
||||
f"Please reduce the length of the messages or completion.",
|
||||
)
|
||||
else:
|
||||
return input_ids, None
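The length check above enforces the budget prompt_tokens + max_tokens <= max_model_len. A tiny standalone sketch of the same arithmetic (the helper name and numbers are illustrative):

def exceeds_context(token_num: int, max_tokens: int, max_model_len: int) -> bool:
    # Same inequality as check_length() above.
    return token_num + max_tokens > max_model_len

max_model_len = 2048
print(exceeds_context(token_num=1900, max_tokens=256, max_model_len=max_model_len))  # True
print(exceeds_context(token_num=1900, max_tokens=100, max_model_len=max_model_len))  # False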
@app.get("/v1/models")
|
||||
async def show_available_models():
|
||||
"""Show available models. Right now we only have one model."""
|
||||
model_cards = [ModelCard(id=served_model, root=served_model,
|
||||
permission=[ModelPermission()])]
|
||||
model_cards = [
|
||||
ModelCard(id=served_model,
|
||||
root=served_model,
|
||||
permission=[ModelPermission()])
|
||||
]
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
|
||||
@ -76,17 +167,192 @@ def create_logprobs(token_ids: List[int],
|
||||
if len(logprobs.text_offset) == 0:
|
||||
logprobs.text_offset.append(initial_text_offset)
|
||||
else:
|
||||
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
|
||||
logprobs.text_offset.append(logprobs.text_offset[-1] +
|
||||
last_token_len)
|
||||
last_token_len = len(token)
|
||||
|
||||
logprobs.top_logprobs.append(
|
||||
{tokenizer.convert_ids_to_tokens(i): p
|
||||
for i, p in id_logprob.items()})
|
||||
logprobs.top_logprobs.append({
|
||||
tokenizer.convert_ids_to_tokens(i): p
|
||||
for i, p in id_logprob.items()
|
||||
})
|
||||
return logprobs
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def create_chat_completion(request: ChatCompletionRequest,
|
||||
raw_request: Request):
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/chat/create
|
||||
for the API specification. This API mimics the OpenAI ChatCompletion API.
|
||||
|
||||
NOTE: Currently we do not support the following features:
|
||||
- function_call (Users should implement this by themselves)
|
||||
- logit_bias (to be supported by vLLM engine)
|
||||
"""
|
||||
logger.info(f"Received chat completion request: {request}")
|
||||
|
||||
if not engine.is_running:
|
||||
engine.start_background_loop()
|
||||
|
||||
error_check_ret = await check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
if request.logit_bias is not None:
|
||||
# TODO: support logit_bias in vLLM engine.
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"logit_bias is not currently supported")
|
||||
|
||||
prompt = await get_gen_prompt(request)
|
||||
token_ids, error_check_ret = await check_length(request, prompt=prompt)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
model_name = request.model
|
||||
request_id = f"cmpl-{random_uuid()}"
|
||||
created_time = int(time.time())
|
||||
try:
|
||||
sampling_params = SamplingParams(
|
||||
n=request.n,
|
||||
presence_penalty=request.presence_penalty,
|
||||
frequency_penalty=request.frequency_penalty,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
stop=request.stop,
|
||||
max_tokens=request.max_tokens,
|
||||
best_of=request.best_of,
|
||||
top_k=request.top_k,
|
||||
ignore_eos=request.ignore_eos,
|
||||
use_beam_search=request.use_beam_search,
|
||||
)
|
||||
except ValueError as e:
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||
|
||||
result_generator = engine.generate(prompt, sampling_params, request_id,
|
||||
token_ids)
|
||||
|
||||
async def abort_request() -> None:
|
||||
await engine.abort(request_id)
|
||||
|
||||
def create_stream_response_json(
|
||||
index: int,
|
||||
text: str,
|
||||
finish_reason: Optional[str] = None,
|
||||
) -> str:
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=text),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
response = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[choice_data],
|
||||
)
|
||||
response_json = response.json(ensure_ascii=False)
|
||||
|
||||
return response_json
|
||||
|
||||
async def completion_stream_generator() -> AsyncGenerator[str, None]:
|
||||
# First chunk with role
|
||||
for i in range(request.n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(role="assistant"),
|
||||
finish_reason=None,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(id=request_id,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
data = chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
previous_texts = [""] * request.n
|
||||
previous_num_tokens = [0] * request.n
|
||||
async for res in result_generator:
|
||||
res: RequestOutput
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
delta_text = output.text[len(previous_texts[i]):]
|
||||
previous_texts[i] = output.text
|
||||
previous_num_tokens[i] = len(output.token_ids)
|
||||
response_json = create_stream_response_json(
|
||||
index=i,
|
||||
text=delta_text,
|
||||
)
|
||||
yield f"data: {response_json}\n\n"
|
||||
if output.finish_reason is not None:
|
||||
response_json = create_stream_response_json(
|
||||
index=i,
|
||||
text="",
|
||||
finish_reason=output.finish_reason,
|
||||
)
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
# Streaming response
|
||||
if request.stream:
|
||||
background_tasks = BackgroundTasks()
|
||||
# Abort the request if the client disconnects.
|
||||
background_tasks.add_task(abort_request)
|
||||
return StreamingResponse(completion_stream_generator(),
|
||||
media_type="text/event-stream",
|
||||
background=background_tasks)
|
||||
|
||||
# Non-streaming response
|
||||
final_res: RequestOutput = None
|
||||
async for res in result_generator:
|
||||
if await raw_request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await abort_request()
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"Client disconnected")
|
||||
final_res = res
|
||||
assert final_res is not None
|
||||
choices = []
|
||||
for output in final_res.outputs:
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=ChatMessage(role="assistant", content=output.text),
|
||||
finish_reason=output.finish_reason,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in final_res.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
response = ChatCompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
response_json = response.json(ensure_ascii=False)
|
||||
|
||||
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(fake_stream_generator(),
|
||||
media_type="text/event-stream")
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def create_completion(raw_request: Request):
|
||||
async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/completions/create
|
||||
@ -99,9 +365,11 @@ async def create_completion(raw_request: Request):
|
||||
suffix)
|
||||
- logit_bias (to be supported by vLLM engine)
|
||||
"""
|
||||
request = CompletionRequest(**await raw_request.json())
|
||||
logger.info(f"Received completion request: {request}")
|
||||
|
||||
if not engine.is_running:
|
||||
engine.start_background_loop()
|
||||
|
||||
error_check_ret = await check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
@ -124,7 +392,34 @@ async def create_completion(raw_request: Request):
|
||||
|
||||
model_name = request.model
|
||||
request_id = f"cmpl-{random_uuid()}"
|
||||
|
||||
use_token_ids = False
|
||||
if isinstance(request.prompt, list):
|
||||
if len(request.prompt) == 0:
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"please provide at least one prompt")
|
||||
first_element = request.prompt[0]
|
||||
if isinstance(first_element, int):
|
||||
use_token_ids = True
|
||||
prompt = request.prompt
|
||||
elif isinstance(first_element, (str, list)):
|
||||
# TODO: handle the multiple-prompt case for List[List[int]]
|
||||
if len(request.prompt) > 1:
|
||||
return create_error_response(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"multiple prompts in a batch is not currently supported")
|
||||
use_token_ids = not isinstance(first_element, str)
|
||||
prompt = request.prompt[0]
|
||||
else:
|
||||
prompt = request.prompt
|
||||
|
||||
if use_token_ids:
|
||||
_, error_check_ret = await check_length(request, prompt_ids=prompt)
|
||||
else:
|
||||
token_ids, error_check_ret = await check_length(request, prompt=prompt)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
created_time = int(time.time())
|
||||
try:
|
||||
sampling_params = SamplingParams(
|
||||
@ -144,22 +439,30 @@ async def create_completion(raw_request: Request):
|
||||
except ValueError as e:
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||
|
||||
result_generator = engine.generate(prompt, sampling_params,
|
||||
request_id)
|
||||
if use_token_ids:
|
||||
result_generator = engine.generate(None,
|
||||
sampling_params,
|
||||
request_id,
|
||||
prompt_token_ids=prompt)
|
||||
else:
|
||||
result_generator = engine.generate(prompt, sampling_params, request_id,
|
||||
token_ids)
|
||||
|
||||
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||
# results. In addition, we do not stream the results when use beam search.
|
||||
stream = (request.stream and
|
||||
(request.best_of is None or request.n == request.best_of) and
|
||||
not request.use_beam_search)
|
||||
stream = (request.stream
|
||||
and (request.best_of is None or request.n == request.best_of)
|
||||
and not request.use_beam_search)
|
||||
|
||||
async def abort_request() -> None:
|
||||
await engine.abort(request_id)
|
||||
|
||||
def create_stream_response_json(index: int,
|
||||
def create_stream_response_json(
|
||||
index: int,
|
||||
text: str,
|
||||
logprobs: Optional[LogProbs] = None,
|
||||
finish_reason: Optional[str] = None) -> str:
|
||||
finish_reason: Optional[str] = None,
|
||||
) -> str:
|
||||
choice_data = CompletionResponseStreamChoice(
|
||||
index=index,
|
||||
text=text,
|
||||
@ -200,7 +503,8 @@ async def create_completion(raw_request: Request):
|
||||
)
|
||||
yield f"data: {response_json}\n\n"
|
||||
if output.finish_reason is not None:
|
||||
logprobs = LogProbs() if request.logprobs is not None else None
|
||||
logprobs = (LogProbs()
|
||||
if request.logprobs is not None else None)
|
||||
response_json = create_stream_response_json(
|
||||
index=i,
|
||||
text="",
|
||||
@ -244,8 +548,8 @@ async def create_completion(raw_request: Request):
|
||||
choices.append(choice_data)
|
||||
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
num_generated_tokens = sum(len(output.token_ids)
|
||||
for output in final_res.outputs)
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in final_res.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
@ -263,9 +567,11 @@ async def create_completion(raw_request: Request):
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
response_json = response.json(ensure_ascii=False)
|
||||
|
||||
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(fake_stream_generator(),
|
||||
media_type="text/event-stream")
|
||||
|
||||
@ -274,26 +580,34 @@ async def create_completion(raw_request: Request):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="vLLM OpenAI-Compatible RESTful API server."
|
||||
)
|
||||
parser.add_argument("--host", type=str, default="localhost", help="host name")
|
||||
description="vLLM OpenAI-Compatible RESTful API server.")
|
||||
parser.add_argument("--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="host name")
|
||||
parser.add_argument("--port", type=int, default=8000, help="port number")
|
||||
parser.add_argument(
|
||||
"--allow-credentials", action="store_true", help="allow credentials"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
|
||||
)
|
||||
parser.add_argument("--served-model-name", type=str, default=None,
|
||||
help="The model name used in the API. If not specified, "
|
||||
"the model name will be the same as the "
|
||||
"huggingface name.")
|
||||
parser.add_argument("--allow-credentials",
|
||||
action="store_true",
|
||||
help="allow credentials")
|
||||
parser.add_argument("--allowed-origins",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed origins")
|
||||
parser.add_argument("--allowed-methods",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed methods")
|
||||
parser.add_argument("--allowed-headers",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed headers")
|
||||
parser.add_argument("--served-model-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The model name used in the API. If not "
|
||||
"specified, the model name will be the same as "
|
||||
"the huggingface name.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -307,13 +621,24 @@ if __name__ == "__main__":
|
||||
|
||||
logger.info(f"args: {args}")
|
||||
|
||||
served_model = args.served_model_name or args.model
|
||||
if args.served_model_name is not None:
|
||||
served_model = args.served_model_name
|
||||
else:
|
||||
served_model = args.model
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args,
|
||||
start_engine_loop=False)
|
||||
engine_model_config = asyncio.run(engine.get_model_config())
|
||||
max_model_len = engine_model_config.get_max_model_len()
|
||||
|
||||
# A separate tokenizer to map token IDs to strings.
|
||||
tokenizer = get_tokenizer(args.model)
|
||||
tokenizer = get_tokenizer(engine_args.tokenizer,
|
||||
tokenizer_mode=engine_args.tokenizer_mode,
|
||||
trust_remote_code=engine_args.trust_remote_code)
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info",
|
||||
uvicorn.run(app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level="info",
|
||||
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
|
||||
# Adapted from
|
||||
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
|
||||
import time
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
@ -53,21 +54,28 @@ class UsageInfo(BaseModel):
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[Dict[str, str]]
|
||||
messages: Union[str, List[Dict[str, str]]]
|
||||
temperature: Optional[float] = 0.7
|
||||
top_p: Optional[float] = 1.0
|
||||
n: Optional[int] = 1
|
||||
max_tokens: Optional[int] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
max_tokens: Optional[int] = 16
|
||||
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
||||
stream: Optional[bool] = False
|
||||
presence_penalty: Optional[float] = 0.0
|
||||
frequency_penalty: Optional[float] = 0.0
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
user: Optional[str] = None
|
||||
# Additional parameters supported by vLLM
|
||||
best_of: Optional[int] = None
|
||||
top_k: Optional[int] = -1
|
||||
ignore_eos: Optional[bool] = False
|
||||
use_beam_search: Optional[bool] = False
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
model: str
|
||||
prompt: str
|
||||
# a string, array of strings, array of tokens, or array of token arrays
|
||||
prompt: Union[List[int], List[List[int]], str, List[str]]
|
||||
suffix: Optional[str] = None
|
||||
max_tokens: Optional[int] = 16
|
||||
temperature: Optional[float] = 1.0
|
||||
@ -92,7 +100,8 @@ class LogProbs(BaseModel):
|
||||
text_offset: List[int] = Field(default_factory=list)
|
||||
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||
tokens: List[str] = Field(default_factory=list)
|
||||
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
|
||||
top_logprobs: List[Optional[Dict[str,
|
||||
float]]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class CompletionResponseChoice(BaseModel):
|
||||
@ -124,3 +133,42 @@ class CompletionStreamResponse(BaseModel):
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[CompletionResponseStreamChoice]
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(BaseModel):
|
||||
index: int
|
||||
message: ChatMessage
|
||||
finish_reason: Optional[Literal["stop", "length"]] = None
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
|
||||
object: str = "chat.completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseChoice]
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
index: int
|
||||
delta: DeltaMessage
|
||||
finish_reason: Optional[Literal["stop", "length"]] = None
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(BaseModel):
|
||||
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
|
||||
object: str = "chat.completion.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseStreamChoice]
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
# Adapted from https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
|
||||
"""Logging configuration for vLLM."""
|
||||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
|
||||
_DATE_FORMAT = "%m-%d %H:%M:%S"
|
||||
|
||||
|
||||
@ -2,7 +2,6 @@ from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
|
||||
|
||||
__all__ = [
|
||||
"InputMetadata",
|
||||
"get_model",
|
||||
|
||||
@ -1,18 +1,29 @@
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||
from xformers.ops import AttentionBias
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
|
||||
class InputMetadata:
|
||||
"""Metadata for input sequences. Used for PagedAttention.
|
||||
|
||||
Args:
|
||||
seq_groups: List of (seq_ids, sampling_params).
|
||||
seq_data: Seq_id -> SequenceData.
|
||||
prompt_lens: Lengths of prompts.
|
||||
slot_mapping: The address to write the new KV to of each token.
|
||||
context_lens: the length of attention context for each generation token.
|
||||
max_context_len: The maximum context length.
|
||||
block_tables: The block tables. (Seq id -> list of physical block)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
seq_groups: List[Tuple[List[int], SamplingParams]], # List of (seq_ids, sampling_params).
|
||||
seq_data: Dict[int, SequenceData], # Seq_id -> SequenceData.
|
||||
seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||
seq_data: Dict[int, SequenceData],
|
||||
prompt_lens: List[int],
|
||||
slot_mapping: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
@ -27,7 +38,6 @@ class InputMetadata:
|
||||
self.max_context_len = max_context_len
|
||||
self.block_tables = block_tables
|
||||
|
||||
self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
|
||||
self.num_prompts = len(prompt_lens)
|
||||
self.num_prompt_tokens = sum(prompt_lens)
|
||||
self.num_generation_tokens = context_lens.shape[0]
|
||||
@ -39,6 +49,9 @@ class InputMetadata:
|
||||
assert block_tables.shape[0] == self.num_generation_tokens
|
||||
assert context_lens.shape[0] == self.num_generation_tokens
|
||||
|
||||
# Set during the execution of the first attention op.
|
||||
self.attn_bias: List[AttentionBias] = []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# Print only useful metadata.
|
||||
return (f'InputMetadata('
|
||||
|
||||
@ -4,10 +4,50 @@ import torch.nn as nn
|
||||
|
||||
from vllm import activation_ops
|
||||
|
||||
|
||||
class SiluAndMul(nn.Module):
|
||||
"""An activation function for SwiGLU.
|
||||
|
||||
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
|
||||
|
||||
Shapes:
|
||||
x: (num_tokens, 2 * d)
|
||||
return: (num_tokens, d)
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens = x.shape[0]
|
||||
d = x.shape[1] // 2
|
||||
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
|
||||
activation_ops.silu_and_mul(out, x)
|
||||
return out
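The custom CUDA op above computes silu(x[:, :d]) * x[:, d:]. A pure-PyTorch reference that should match it numerically is a convenient sanity check; this reference is my sketch, not part of the diff:

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    """Pure-PyTorch reference for SiluAndMul: silu(x[:, :d]) * x[:, d:]."""
    d = x.shape[1] // 2
    return F.silu(x[:, :d]) * x[:, d:]

x = torch.randn(4, 2 * 11)  # (num_tokens, 2 * d)
out = silu_and_mul_ref(x)
assert out.shape == (4, 11)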
class NewGELU(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens = x.shape[0]
|
||||
d = x.shape[1]
|
||||
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
|
||||
activation_ops.gelu_new(out, x)
|
||||
return out
|
||||
|
||||
|
||||
class FastGELU(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens = x.shape[0]
|
||||
d = x.shape[1]
|
||||
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
|
||||
activation_ops.gelu_fast(out, x)
|
||||
return out
|
||||
|
||||
|
||||
_ACTIVATION_REGISTRY = {
|
||||
"gelu": nn.GELU(),
|
||||
"gelu_new": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors.
|
||||
"gelu_fast": nn.GELU(approximate="tanh"), # NOTE: This may introduce small rounding errors.
|
||||
"gelu_fast": FastGELU(),
|
||||
"gelu_new": NewGELU(),
|
||||
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
|
||||
"relu": nn.ReLU(),
|
||||
}
|
||||
|
||||
@ -18,23 +58,3 @@ def get_act_fn(act_fn: str) -> nn.Module:
|
||||
if act_fn in _ACTIVATION_REGISTRY:
|
||||
return _ACTIVATION_REGISTRY[act_fn]
|
||||
raise ValueError(f"Activation function {act_fn!r} is not supported.")
|
||||
|
||||
|
||||
class SiluAndMul(nn.Module):
|
||||
"""An activation function for SwiGLU.
|
||||
|
||||
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor, # (num_tokens, 2 * d)
|
||||
) -> torch.Tensor: # (num_tokens, d)
|
||||
num_tokens = x.shape[0]
|
||||
d = x.shape[1] // 2
|
||||
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
|
||||
activation_ops.silu_and_mul(out, x)
|
||||
return out
|
||||
|
||||
@ -1,28 +1,39 @@
|
||||
"""Multi-head attention."""
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
|
||||
LowerTriangularMaskWithTensorBias)
|
||||
|
||||
from vllm import attention_ops
|
||||
from vllm import cache_ops
|
||||
from vllm import pos_encoding_ops
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
|
||||
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]
|
||||
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||
|
||||
|
||||
class PagedAttention(nn.Module):
|
||||
# pylint: disable=line-too-long
|
||||
"""GPT-style multi-head PagedAttention.
|
||||
|
||||
This class takes flattened 1D query, key, and value tensors as input. The
|
||||
input 1D tensors can be split into three parts: the prompt tokens, the
|
||||
generation tokens, and the paddings.
|
||||
input 1D tensors can either contain prompt tokens or generation tokens, in
|
||||
addition to paddings.
|
||||
|
||||
|<------------------------------------- num_valid_tokens ------------------------------------->|
|
||||
|<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
|
||||
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
|
||||
If the input tensors contain prompt tokens, the layout is as follows:
|
||||
|
||||
|<---------------------- num_valid_tokens ---------------------->|
|
||||
|<--------------- num_prompt_tokens -------------->|
|
||||
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|
|
||||
|
||||
Otherwise, the layout is as follows:
|
||||
|
||||
|<------------------ num_valid_tokens ------------------->|
|
||||
|<------- num_generation_tokens (M) ------->|
|
||||
|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
|
||||
|
||||
The prompts might have different lengths, while the generation tokens always
|
||||
have length 1. The paddings are appended to make the input length a multiple
|
||||
@ -41,34 +52,68 @@ class PagedAttention(nn.Module):
|
||||
5. Output a flattened 1D tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
|
||||
def __init__(self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.attn_op = xops.fmha.cutlass.FwOp()
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.head_mapping = torch.repeat_interleave(
|
||||
torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
|
||||
self.num_queries_per_kv)
|
||||
|
||||
if self.head_size not in _SUPPORTED_HEAD_SIZES:
|
||||
raise ValueError(f"head_size ({self.head_size}) is not supported. "
|
||||
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
|
||||
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
|
||||
if input_metadata.attn_bias:
|
||||
# Already set by a previous layer.
|
||||
return
|
||||
prompt_lens = input_metadata.prompt_lens
|
||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
|
||||
input_metadata.attn_bias.append(attn_bias)
|
||||
|
||||
def multi_query_kv_attention(
|
||||
self,
|
||||
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
attn_bias: xops.AttentionBias,
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
"""Normal attention for the prompt tokens.
|
||||
|
||||
Args:
|
||||
output: shape = [num_prompt_tokens, num_heads, head_size]
|
||||
query: shape = [num_prompt_tokens, num_heads, head_size]
|
||||
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
|
||||
input_metadata: metadata for paged attention.
|
||||
"""
|
||||
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
# Project the key and value tensors to the desired number of heads.
|
||||
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
|
||||
value = torch.repeat_interleave(value,
|
||||
self.num_queries_per_kv,
|
||||
dim=1)
|
||||
|
||||
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query.unsqueeze(0),
|
||||
key.unsqueeze(0),
|
||||
value.unsqueeze(0),
|
||||
attn_bias=attn_bias,
|
||||
attn_bias=input_metadata.attn_bias[0],
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op=self.attn_op,
|
||||
)
|
||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||
output.copy_(out.squeeze(0))
|
||||
@ -76,42 +121,72 @@ class PagedAttention(nn.Module):
|
||||
|
||||
def single_query_cached_kv_attention(
|
||||
self,
|
||||
output: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
|
||||
query: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
|
||||
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> None:
|
||||
"""PagedAttention for the generation tokens.
|
||||
|
||||
Args:
|
||||
output: shape = [num_generation_tokens, num_heads, head_size]
|
||||
query: shape = [num_generation_tokens, num_heads, head_size]
|
||||
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
|
||||
block_size, x]
|
||||
value_cache: shape = [num_blocks, num_kv_heads, head_size,
|
||||
block_size]
|
||||
input_metadata: metadata for paged attention.
|
||||
"""
|
||||
block_size = value_cache.shape[3]
|
||||
attention_ops.single_query_cached_kv_attention(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
self.head_mapping,
|
||||
self.scale,
|
||||
input_metadata.block_tables,
|
||||
input_metadata.context_lens,
|
||||
block_size,
|
||||
input_metadata.max_context_len,
|
||||
None, # alibi_slopes
|
||||
)
|
||||
|
||||
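The cache shapes documented above can be sanity-checked with a short sketch; the sizes below are made up, and the packing factor x is assumed to be 16 bytes divided by the element size, which is what the paged-attention kernel typically expects:

import torch

num_blocks, num_kv_heads, head_size, block_size = 128, 2, 64, 16  # illustrative
x = 16 // torch.tensor([], dtype=torch.float16).element_size()    # 8 for fp16

key_cache = torch.empty(num_blocks, num_kv_heads, head_size // x, block_size, x,
                        dtype=torch.float16)
value_cache = torch.empty(num_blocks, num_kv_heads, head_size, block_size,
                          dtype=torch.float16)
print(key_cache.shape, value_cache.shape)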
def forward(
|
||||
self,
|
||||
query: torch.Tensor, # [num_tokens, num_heads * head_size]
|
||||
key: torch.Tensor, # [num_tokens, num_heads * head_size]
|
||||
value: torch.Tensor, # [num_tokens, num_heads * head_size]
|
||||
key_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
value_cache: Optional[torch.Tensor], # [num_blocks, num_heads, head_size, block_size]
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: Optional[torch.Tensor],
|
||||
value_cache: Optional[torch.Tensor],
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
|
||||
# NOTE: The query, key, and value tensors must be sliced from a qkv
|
||||
# tensor of shape [num_tokens, 3 * num_heads * head_size].
|
||||
) -> torch.Tensor:
|
||||
"""PagedAttention forward pass.
|
||||
|
||||
NOTE: The query, key, and value tensors must be sliced from a qkv
|
||||
tensor of shape [num_tokens, 3 * num_heads * head_size].
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
|
||||
block_size, x]
|
||||
value_cache: shape = [num_blocks, num_kv_heads, head_size,
|
||||
block_size]
|
||||
input_metadata: metadata for paged attention.
|
||||
cache_event: event to wait for the cache operations to finish.
|
||||
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_heads, self.head_size)
|
||||
value = value.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
# Pre-allocate the output tensor.
|
||||
output = torch.empty_like(query)
|
||||
@ -119,12 +194,15 @@ class PagedAttention(nn.Module):
|
||||
# Compute the attention op for prompts.
|
||||
num_prompt_tokens = input_metadata.num_prompt_tokens
|
||||
if num_prompt_tokens > 0:
|
||||
# Prompt run.
|
||||
assert input_metadata.num_generation_tokens == 0
|
||||
self.set_attn_bias(input_metadata)
|
||||
self.multi_query_kv_attention(
|
||||
output[:num_prompt_tokens],
|
||||
query[:num_prompt_tokens],
|
||||
key[:num_prompt_tokens],
|
||||
value[:num_prompt_tokens],
|
||||
input_metadata.attn_bias,
|
||||
input_metadata,
|
||||
)
|
||||
|
||||
# Wait until the cache op is done.
|
||||
@ -147,17 +225,16 @@ class PagedAttention(nn.Module):
|
||||
)
|
||||
|
||||
if input_metadata.num_generation_tokens > 0:
|
||||
# Decoding run.
|
||||
assert input_metadata.num_prompt_tokens == 0
|
||||
assert key_cache is not None and value_cache is not None, (
|
||||
"key_cache and value_cache must be provided when "
|
||||
"generating tokens."
|
||||
)
|
||||
"generating tokens.")
|
||||
# Compute the attention op for generation tokens.
|
||||
self.single_query_cached_kv_attention(
|
||||
output[num_prompt_tokens:num_valid_tokens],
|
||||
query[num_prompt_tokens:num_valid_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata)
|
||||
query[num_prompt_tokens:num_valid_tokens], key_cache,
|
||||
value_cache, input_metadata)
|
||||
|
||||
# Reshape the output tensor.
|
||||
# NOTE(woosuk): The output tensor may include paddings.
|
||||
@ -165,7 +242,7 @@ class PagedAttention(nn.Module):
|
||||
|
||||
|
||||
class PagedAttentionWithRoPE(PagedAttention):
|
||||
"""PagedAttention with GPT-NeoX style rotary embedding."""
|
||||
"""PagedAttention with rotary embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -175,19 +252,23 @@ class PagedAttentionWithRoPE(PagedAttention):
|
||||
rotary_dim: int,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
is_neox_style: bool = True,
|
||||
) -> None:
|
||||
super().__init__(num_heads, head_size, scale)
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads)
|
||||
self.is_neox_style = is_neox_style
|
||||
|
||||
# Create the cos and sin cache.
|
||||
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
|
||||
t = torch.arange(max_position).float()
|
||||
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
|
||||
# FIXME(woosuk): This assumes that we configure the default dtype when
|
||||
# initializing the model. Make it more robust.
|
||||
# initializing the model.
|
||||
# TODO(woosuk): Make it more robust.
|
||||
torch_dtype = torch.get_default_dtype()
|
||||
cache = cache.to(torch_dtype)
|
||||
# Embedding size: [max_position, rotary_dim]
|
||||
@ -195,23 +276,42 @@ class PagedAttentionWithRoPE(PagedAttention):
|
||||
|
||||
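A standalone sketch of the cos/sin cache construction above, with small illustrative sizes, confirming that the cached tensor ends up with shape [max_position, rotary_dim]:

import torch

rotary_dim, max_position, base = 8, 16, 10000  # illustrative
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum("i,j -> ij", t, inv_freq)   # [max_position, rotary_dim // 2]
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
print(cache.shape)  # torch.Size([16, 8])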
def forward(
|
||||
self,
|
||||
positions: torch.Tensor, # [num_tokens]
|
||||
query: torch.Tensor, # [num_tokens, num_heads * head_size]
|
||||
key: torch.Tensor, # [num_tokens, num_heads * head_size]
|
||||
value: torch.Tensor, # [num_tokens, num_heads * head_size]
|
||||
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
|
||||
) -> torch.Tensor:
|
||||
""" PagedAttention forward pass with rotary embedding.
|
||||
|
||||
Args:
|
||||
positions: shape = [num_tokens]
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
|
||||
block_size, x]
|
||||
value_cache: shape = [num_blocks, num_kv_heads, head_size,
|
||||
block_size]
|
||||
input_metadata: metadata for paged attention.
|
||||
cache_event: event to wait for the cache operations to finish.
|
||||
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
|
||||
# Apply rotary embedding to the query and key before passing them
|
||||
# to the attention op.
|
||||
pos_encoding_ops.rotary_embedding_neox(
|
||||
pos_encoding_ops.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
return super().forward(
|
||||
query,
|
||||
@ -222,3 +322,125 @@ class PagedAttentionWithRoPE(PagedAttention):
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
|
||||
|
||||
class PagedAttentionWithALiBi(PagedAttention):
|
||||
"""PagedAttention with ALiBi attention bias."""
|
||||
|
||||
def __init__(self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
slopes: List[float],
|
||||
num_kv_heads: Optional[int] = None) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads)
|
||||
assert len(slopes) == num_heads
|
||||
|
||||
slopes = torch.tensor(slopes, dtype=torch.float32)
|
||||
self.register_buffer("alibi_slopes", slopes, persistent=False)
|
||||
|
||||
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
|
||||
if input_metadata.attn_bias:
|
||||
# Already set by a previous layer.
|
||||
return
|
||||
# Generates ALiBi mask for each prompt.
|
||||
for prompt_len in input_metadata.prompt_lens:
|
||||
bias = torch.arange(prompt_len)
|
||||
# Note(zhuohan): HF uses
|
||||
# `bias = bias[None, :].repeat(prompt_len, 1)`
|
||||
# here. We find that both biases give the same results, but
|
||||
# the bias below more accurately follows the original ALiBi
|
||||
# paper.
|
||||
bias = bias[None, :] - bias[:, None]
|
||||
bias = bias.to(self.alibi_slopes.device)
|
||||
|
||||
# When using custom attention bias, xformers requires the bias to
|
||||
# be sliced from a tensor whose length is a multiple of 8.
|
||||
padded_len = (prompt_len + 7) // 8 * 8
|
||||
bias = torch.empty(
|
||||
1, # batch_size
|
||||
self.num_heads,
|
||||
prompt_len,
|
||||
padded_len,
|
||||
device=self.alibi_slopes.device,
|
||||
)[:, :, :, :prompt_len].copy_(bias)
|
||||
bias.mul_(self.alibi_slopes[:, None, None])
|
||||
attn_bias = LowerTriangularMaskWithTensorBias(bias)
|
||||
input_metadata.attn_bias.append(attn_bias)
|
||||
|
||||
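A small sketch of the ALiBi bias assembled in set_attn_bias above, assuming 4 heads and a 5-token prompt; the last dimension is allocated as a multiple of 8 to satisfy the xformers layout requirement and then sliced back to the prompt length:

import torch

num_heads, prompt_len = 4, 5                                  # illustrative
slopes = torch.tensor([0.25, 0.0625, 0.015625, 0.00390625])   # standard ALiBi slopes for 4 heads

pos = torch.arange(prompt_len)
bias = pos[None, :] - pos[:, None]                # relative positions, [L, L]

padded_len = (prompt_len + 7) // 8 * 8            # 8
padded = torch.empty(1, num_heads, prompt_len, padded_len)[:, :, :, :prompt_len]
padded.copy_(bias)                                # broadcast over the head dim
padded.mul_(slopes[:, None, None])                # scale each head by its slope
print(padded.shape)  # torch.Size([1, 4, 5, 5]), a view into the padded buffer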
def multi_query_kv_attention(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
"""Attention with ALiBi bias for the prompt tokens.
|
||||
|
||||
Args:
|
||||
output: shape = [num_prompt_tokens, num_heads, head_size]
|
||||
query: shape = [num_prompt_tokens, num_heads, head_size]
|
||||
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
|
||||
input_metadata: metadata for paged attention.
|
||||
"""
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
# Project the key and value tensors to the desired number of heads.
|
||||
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
|
||||
value = torch.repeat_interleave(value,
|
||||
self.num_queries_per_kv,
|
||||
dim=1)
|
||||
|
||||
# FIXME(woosuk): Because xformers does not support dynamic sequence
|
||||
# lengths with custom attention bias, we process each prompt one by
|
||||
# one. This is inefficient, especially when we have many short prompts.
|
||||
start = 0
|
||||
for i, prompt_len in enumerate(input_metadata.prompt_lens):
|
||||
end = start + prompt_len
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query[None, start:end],
|
||||
key[None, start:end],
|
||||
value[None, start:end],
|
||||
attn_bias=input_metadata.attn_bias[i],
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
)
|
||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||
output[start:end].copy_(out.squeeze(0))
|
||||
start += prompt_len
|
||||
return output
|
||||
|
||||
def single_query_cached_kv_attention(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> None:
|
||||
"""PagedAttention with ALiBi bias for the generation tokens.
|
||||
|
||||
Args:
|
||||
output: shape = [num_generation_tokens, num_heads, head_size]
|
||||
query: shape = [num_generation_tokens, num_heads, head_size]
|
||||
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
|
||||
block_size, x]
|
||||
value_cache: shape = [num_blocks, num_kv_heads, head_size,
|
||||
block_size]
|
||||
input_metadata: metadata for paged attention.
|
||||
"""
|
||||
block_size = value_cache.shape[3]
|
||||
attention_ops.single_query_cached_kv_attention(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
self.head_mapping,
|
||||
self.scale,
|
||||
input_metadata.block_tables,
|
||||
input_metadata.context_lens,
|
||||
block_size,
|
||||
input_metadata.max_context_len,
|
||||
self.alibi_slopes,
|
||||
)
|
||||
|
||||
@ -9,7 +9,9 @@ from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
gather_from_tensor_model_parallel_region)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput, SequenceOutputs
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
@ -36,12 +38,15 @@ class Sampler(nn.Module):
|
||||
embedding: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
embedding_bias: Optional[torch.Tensor] = None,
|
||||
) -> SamplerOutput:
|
||||
# Get the hidden states that we use for sampling.
|
||||
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
|
||||
|
||||
# Get the logits for the next tokens.
|
||||
logits = torch.matmul(hidden_states, embedding.t())
|
||||
if embedding_bias is not None:
|
||||
logits += embedding_bias
|
||||
logits = gather_from_tensor_model_parallel_region(logits)
|
||||
# Remove paddings in vocab (if any).
|
||||
logits = logits[:, :self.vocab_size]
|
||||
@ -49,34 +54,37 @@ class Sampler(nn.Module):
|
||||
# Apply presence and frequency penalties.
|
||||
output_tokens = _get_output_tokens(input_metadata)
|
||||
assert len(output_tokens) == logits.shape[0]
|
||||
presence_penalties, frequency_penalties = _get_penalties(input_metadata)
|
||||
presence_penalties, frequency_penalties = _get_penalties(
|
||||
input_metadata)
|
||||
assert len(presence_penalties) == logits.shape[0]
|
||||
assert len(frequency_penalties) == logits.shape[0]
|
||||
logits = _apply_penalties(
|
||||
logits, output_tokens, presence_penalties, frequency_penalties,
|
||||
self.vocab_size)
|
||||
logits = _apply_penalties(logits, output_tokens, presence_penalties,
|
||||
frequency_penalties, self.vocab_size)
|
||||
|
||||
# Apply temperature scaling.
|
||||
temperatures = _get_temperatures(input_metadata)
|
||||
assert len(temperatures) == logits.shape[0]
|
||||
if any(t != 1.0 for t in temperatures):
|
||||
t = torch.tensor(
|
||||
temperatures, dtype=logits.dtype, device=logits.device)
|
||||
t = torch.tensor(temperatures,
|
||||
dtype=logits.dtype,
|
||||
device=logits.device)
|
||||
# Use in-place division to avoid creating a new tensor.
|
||||
logits.div_(t.unsqueeze(dim=1))
|
||||
|
||||
# Apply top-p and top-k truncation.
|
||||
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
|
||||
assert len(top_ps) == len(top_ks) == logits.shape[0]
|
||||
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
|
||||
do_top_k = any(k != self.vocab_size for k in top_ks)
|
||||
if do_top_p or do_top_k:
|
||||
logits = _apply_top_p_top_k(logits, top_ps, top_ks)
|
||||
|
||||
# We use float32 for probabilities and log probabilities.
|
||||
# Compute the probabilities.
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||
# Compute the log probabilities (before applying top-p and top-k).
|
||||
logprobs = torch.log(probs)
|
||||
|
||||
# Apply top-p and top-k truncation.
|
||||
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
|
||||
assert len(top_ps) == len(top_ks) == probs.shape[0]
|
||||
if any(p < 1.0 for p in top_ps) or any(k != self.vocab_size for k in top_ks):
|
||||
probs = _apply_top_p_top_k(probs, top_ps, top_ks)
|
||||
|
||||
# Sample the next tokens.
|
||||
return _sample(probs, logprobs, input_metadata)
|
||||
|
||||
@ -92,12 +100,12 @@ def _prune_hidden_states(
|
||||
start_idx += prompt_len
|
||||
last_token_indicies.extend(
|
||||
range(start_idx, start_idx + input_metadata.num_generation_tokens))
|
||||
return hidden_states[last_token_indicies]
|
||||
return hidden_states.index_select(
|
||||
0, torch.tensor(last_token_indicies, device=hidden_states.device))
|
||||
|
||||
|
||||
def _get_penalties(
|
||||
input_metadata: InputMetadata,
|
||||
) -> Tuple[List[float], List[float]]:
|
||||
input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
|
||||
# Collect the presence and frequency penalties.
|
||||
presence_penalties: List[float] = []
|
||||
frequency_penalties: List[float] = []
|
||||
@ -116,9 +124,7 @@ def _get_penalties(
|
||||
return presence_penalties, frequency_penalties
|
||||
|
||||
|
||||
def _get_output_tokens(
|
||||
input_metadata: InputMetadata,
|
||||
) -> List[List[int]]:
|
||||
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
|
||||
output_tokens: List[List[int]] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_ids, _ = seq_group
|
||||
@ -152,7 +158,7 @@ def _apply_penalties(
|
||||
continue
|
||||
p = presence_penalties[i]
|
||||
f = frequency_penalties[i]
|
||||
if p == 0.0 and f == 0.0:
|
||||
if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
|
||||
continue
|
||||
indices.append(i)
|
||||
|
||||
@ -168,11 +174,13 @@ def _apply_penalties(
|
||||
device=logits.device)
|
||||
|
||||
frequency_penalties = [frequency_penalties[i] for i in indices]
|
||||
frequency_penalties = torch.tensor(
|
||||
frequency_penalties, dtype=logits.dtype, device=logits.device)
|
||||
frequency_penalties = torch.tensor(frequency_penalties,
|
||||
dtype=logits.dtype,
|
||||
device=logits.device)
|
||||
presence_penalties = [presence_penalties[i] for i in indices]
|
||||
presence_penalties = torch.tensor(
|
||||
presence_penalties, dtype=logits.dtype, device=logits.device)
|
||||
presence_penalties = torch.tensor(presence_penalties,
|
||||
dtype=logits.dtype,
|
||||
device=logits.device)
|
||||
|
||||
# We follow the definition in OpenAI API.
|
||||
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
|
||||
@ -182,15 +190,13 @@ def _apply_penalties(
|
||||
return logits
|
||||
|
||||
|
||||
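A compact reference for the penalty rule referenced above (OpenAI style: subtract frequency_penalty times the token count, plus presence_penalty once for any token that has already appeared), using made-up numbers:

import torch

vocab_size = 8                                   # illustrative
logits = torch.zeros(vocab_size)
output_tokens = [3, 3, 5]                        # tokens generated so far
presence_penalty, frequency_penalty = 0.5, 0.2   # illustrative

counts = torch.bincount(torch.tensor(output_tokens), minlength=vocab_size)
logits -= frequency_penalty * counts             # scales with repetition count
logits -= presence_penalty * (counts > 0)        # flat penalty for any appearance
print(logits)  # token 3: -0.9, token 5: -0.7, all others unchanged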
def _get_temperatures(
|
||||
input_metadata: InputMetadata,
|
||||
) -> List[float]:
|
||||
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
|
||||
# Collect the temperatures for the logits.
|
||||
temperatures: List[float] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_ids, sampling_params = seq_group
|
||||
temperature = sampling_params.temperature
|
||||
if temperature == 0.0:
|
||||
if temperature < _SAMPLING_EPS:
|
||||
# NOTE: Zero temperature means deterministic sampling
|
||||
# (i.e., greedy sampling or beam search).
|
||||
# Set the temperature to 1 to avoid division by zero.
|
||||
@ -230,30 +236,32 @@ def _get_top_p_top_k(
|
||||
|
||||
|
||||
def _apply_top_p_top_k(
|
||||
probs: torch.Tensor,
|
||||
logits: torch.Tensor,
|
||||
top_ps: List[float],
|
||||
top_ks: List[int],
|
||||
) -> torch.Tensor:
|
||||
p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
|
||||
k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
|
||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
||||
p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
|
||||
k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
|
||||
logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
|
||||
|
||||
# Apply top-p.
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
probs_sort = logits_sort.softmax(dim=-1)
|
||||
probs_sum = probs_sort.cumsum(dim=-1)
|
||||
top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
|
||||
probs_sort[top_p_mask] = 0.0
|
||||
logits_sort[top_p_mask] = -float("inf")
|
||||
|
||||
# Apply top-k.
|
||||
# Create a mask for the top-k elements.
|
||||
top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
|
||||
top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
|
||||
top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
|
||||
top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
|
||||
top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
|
||||
probs_sort[top_k_mask] = 0.0
|
||||
logits_sort[top_k_mask] = -float("inf")
|
||||
|
||||
# Re-sort the probabilities.
|
||||
probs = torch.gather(
|
||||
probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
|
||||
return probs
|
||||
logits = torch.gather(logits_sort,
|
||||
dim=-1,
|
||||
index=torch.argsort(logits_idx, dim=-1))
|
||||
return logits
|
||||
|
||||
|
||||
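A tiny worked example of the truncation performed by _apply_top_p_top_k above, using a single row of made-up logits so the masking is easy to follow:

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])
top_p, top_k = 0.8, 3

logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)

# Mask tokens outside the top-p nucleus and beyond the top-k cutoff.
top_p_mask = (probs_sum - probs_sort) > top_p
top_k_mask = torch.arange(logits.shape[-1]).expand(1, -1) >= top_k
logits_sort[top_p_mask | top_k_mask] = -float("inf")

# Undo the sort so surviving logits return to their original positions.
restored = torch.gather(logits_sort, -1, torch.argsort(logits_idx, dim=-1))
print(restored)  # tensor([[2.0, 1.0, 0.5, -inf, -inf]])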
def _get_topk_logprobs(
|
||||
@ -284,9 +292,15 @@ def _sample_from_prompt(
|
||||
if sampling_params.use_beam_search:
|
||||
# Beam search.
|
||||
beam_width = sampling_params.best_of
|
||||
_, next_token_ids = torch.topk(prob, beam_width)
|
||||
# Sample 2 * beam_width candidates to make sure that with high
|
||||
# probability we can get `beam_width` candidates in addition to
|
||||
# the finished sequences for the next iteration. See
|
||||
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
|
||||
# for details. See also HF reference:
|
||||
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
|
||||
_, next_token_ids = torch.topk(prob, 2 * beam_width)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
elif sampling_params.temperature == 0.0:
|
||||
elif sampling_params.temperature < _SAMPLING_EPS:
|
||||
# Greedy sampling.
|
||||
assert sampling_params.best_of == 1
|
||||
next_token_id = torch.argmax(prob)
|
||||
@ -295,8 +309,9 @@ def _sample_from_prompt(
|
||||
# Random sampling.
|
||||
# Sample `best_of` tokens for the prompt.
|
||||
num_seqs = sampling_params.best_of
|
||||
next_token_ids = torch.multinomial(
|
||||
prob, num_samples=num_seqs, replacement=True)
|
||||
next_token_ids = torch.multinomial(prob,
|
||||
num_samples=num_seqs,
|
||||
replacement=True)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
return next_token_ids
|
||||
|
||||
@ -314,36 +329,19 @@ def _sample_from_generation_tokens(
|
||||
if sampling_params.use_beam_search:
|
||||
# Beam search.
|
||||
# Add cumulative logprobs for the sequences in the group.
|
||||
seq_logprobs = torch.tensor(
|
||||
seq_logprobs, dtype=torch.float, device=logprobs.device)
|
||||
seq_logprobs = torch.tensor(seq_logprobs,
|
||||
dtype=torch.float,
|
||||
device=logprobs.device)
|
||||
logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
|
||||
|
||||
vocab_size = logprobs.size(-1)
|
||||
beam_width = len(seq_ids)
|
||||
_, topk_ids = torch.topk(logprobs.flatten(), beam_width)
|
||||
_, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
|
||||
topk_ids = topk_ids.tolist()
|
||||
seq_idx = [i // vocab_size for i in topk_ids]
|
||||
beam_seq_ids = [seq_ids[i] for i in seq_idx]
|
||||
token_ids = [i % vocab_size for i in topk_ids]
|
||||
|
||||
beam_outputs: Dict[int, Tuple[int, int]] = {}
|
||||
outstanding_beams: List[Tuple[int, int]] = []
|
||||
# If a beam survives, continue with it.
|
||||
for seq_id, token_id in zip(beam_seq_ids, token_ids):
|
||||
if seq_id not in beam_outputs:
|
||||
beam_outputs[seq_id] = (seq_id, token_id)
|
||||
else:
|
||||
outstanding_beams.append((seq_id, token_id))
|
||||
|
||||
# If a beam is discarded, fork another beam.
|
||||
for seq_id in seq_ids:
|
||||
if seq_id not in beam_outputs:
|
||||
beam_outputs[seq_id] = outstanding_beams.pop()
|
||||
assert not outstanding_beams
|
||||
|
||||
parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
|
||||
next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
|
||||
elif sampling_params.temperature == 0.0:
|
||||
parent_seq_ids = [seq_ids[i] for i in seq_idx]
|
||||
next_token_ids = [i % vocab_size for i in topk_ids]
|
||||
elif sampling_params.temperature < _SAMPLING_EPS:
|
||||
# Greedy sampling.
|
||||
assert len(seq_ids) == 1
|
||||
next_token_id = torch.argmax(probs, dim=-1)
|
||||
@ -352,8 +350,9 @@ def _sample_from_generation_tokens(
|
||||
else:
|
||||
# Random sampling.
|
||||
# Sample 1 token for each sequence in the group.
|
||||
next_token_ids = torch.multinomial(
|
||||
probs, num_samples=1, replacement=True)
|
||||
next_token_ids = torch.multinomial(probs,
|
||||
num_samples=1,
|
||||
replacement=True)
|
||||
next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
|
||||
parent_seq_ids = seq_ids
|
||||
return parent_seq_ids, next_token_ids
|
||||
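A minimal sketch of the flattened top-k used for beam search above, with a toy vocabulary, showing how the flat indices decode back into (parent beam, token) pairs:

import torch

beam_width, vocab_size = 2, 5                       # illustrative
logprobs = torch.log_softmax(torch.randn(beam_width, vocab_size), dim=-1)
seq_logprobs = torch.tensor([-1.0, -2.5])           # cumulative beam scores so far

total = logprobs + seq_logprobs.unsqueeze(dim=1)    # [beam_width, vocab_size]
_, topk_ids = torch.topk(total.flatten(), 2 * beam_width)

parent_idx = [i // vocab_size for i in topk_ids.tolist()]  # which beam it came from
token_ids = [i % vocab_size for i in topk_ids.tolist()]    # which token it picked
print(list(zip(parent_idx, token_ids)))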
@ -363,16 +362,18 @@ def _sample(
|
||||
probs: torch.Tensor,
|
||||
logprobs: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
seq_outputs: Dict[int, SequenceOutputs] = {}
|
||||
) -> SamplerOutput:
|
||||
seq_outputs: SamplerOutput = []
|
||||
|
||||
# TODO(woosuk): Optimize.
|
||||
idx = 0
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_group_outputs: List[SequenceOutputs] = []
|
||||
seq_ids, sampling_params = seq_group
|
||||
if i < input_metadata.num_prompts:
|
||||
# Generate the next tokens for a prompt input.
|
||||
assert len(seq_ids) == sampling_params.best_of
|
||||
assert len(seq_ids) == 1, "Prompt input should have only one seq."
|
||||
parent_seq_id = seq_ids[0]
|
||||
prob = probs[idx]
|
||||
logprob = logprobs[idx]
|
||||
idx += 1
|
||||
@ -380,45 +381,47 @@ def _sample(
|
||||
# Sample the next tokens.
|
||||
next_token_ids = _sample_from_prompt(prob, sampling_params)
|
||||
# Get top-k log probabilities for the next tokens.
|
||||
next_logprobs = _get_topk_logprobs(
|
||||
logprob, sampling_params.logprobs)
|
||||
next_logprobs = _get_topk_logprobs(logprob,
|
||||
sampling_params.logprobs)
|
||||
|
||||
# Build the output.
|
||||
for seq_id, next_token_id in zip(seq_ids, next_token_ids):
|
||||
for next_token_id in next_token_ids:
|
||||
output_logprobs = next_logprobs.copy()
|
||||
output_logprobs[next_token_id] = logprob[next_token_id].item()
|
||||
seq_outputs[seq_id] = SequenceOutputs(
|
||||
seq_id, seq_id, next_token_id, output_logprobs)
|
||||
seq_group_outputs.append(
|
||||
SequenceOutputs(parent_seq_id, next_token_id,
|
||||
output_logprobs))
|
||||
else:
|
||||
# Generate the next tokens for generation tokens.
|
||||
prob = probs[idx:idx + len(seq_ids)]
|
||||
logprob = logprobs[idx:idx + len(seq_ids)]
|
||||
idx += len(seq_ids)
|
||||
num_parent_seqs = len(seq_ids)
|
||||
prob = probs[idx:idx + num_parent_seqs]
|
||||
logprob = logprobs[idx:idx + num_parent_seqs]
|
||||
idx += num_parent_seqs
|
||||
|
||||
# Sample the next tokens.
|
||||
seq_logprobs = [
|
||||
input_metadata.seq_data[seq_id].cumulative_logprob
|
||||
for seq_id in seq_ids]
|
||||
for seq_id in seq_ids
|
||||
]
|
||||
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
|
||||
seq_ids, prob, logprob, seq_logprobs, sampling_params)
|
||||
|
||||
# Get top-k log probabilities for the next tokens.
|
||||
next_logprobs: Dict[int, Dict[int, float]] = {}
|
||||
for i, seq_id in enumerate(seq_ids):
|
||||
for j, seq_id in enumerate(seq_ids):
|
||||
next_logprobs[seq_id] = _get_topk_logprobs(
|
||||
logprob[i], sampling_params.logprobs)
|
||||
logprob[j], sampling_params.logprobs)
|
||||
|
||||
# Build the output.
|
||||
for seq_id, parent_seq_id, next_token_id in zip(
|
||||
seq_ids, parent_seq_ids, next_token_ids):
|
||||
i = seq_ids.index(parent_seq_id)
|
||||
for parent_seq_id, next_token_id in zip(parent_seq_ids,
|
||||
next_token_ids):
|
||||
j = seq_ids.index(parent_seq_id)
|
||||
output_logprobs = next_logprobs[parent_seq_id].copy()
|
||||
output_logprobs[next_token_id] = logprob[i, next_token_id].item()
|
||||
seq_outputs[seq_id] = SequenceOutputs(
|
||||
seq_id,
|
||||
parent_seq_id,
|
||||
next_token_id,
|
||||
output_logprobs,
|
||||
)
|
||||
output_logprobs[next_token_id] = logprob[j,
|
||||
next_token_id].item()
|
||||
seq_group_outputs.append(
|
||||
SequenceOutputs(parent_seq_id, next_token_id,
|
||||
output_logprobs))
|
||||
seq_outputs.append(seq_group_outputs)
|
||||
|
||||
return seq_outputs
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Utilities for selecting and loading models."""
|
||||
import contextlib
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
@ -6,19 +7,39 @@ import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.model_executor.models import (GPT2LMHeadModel, GPTNeoXForCausalLM,
|
||||
LlamaForCausalLM, OPTForCausalLM)
|
||||
from vllm.model_executor.models import * # pylint: disable=wildcard-import
|
||||
from vllm.model_executor.weight_utils import initialize_dummy_weights
|
||||
|
||||
# TODO(woosuk): Lazy-load the model classes.
|
||||
_MODEL_REGISTRY = {
|
||||
"AquilaModel": AquilaForCausalLM,
|
||||
"BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b
|
||||
"BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b
|
||||
"BloomForCausalLM": BloomForCausalLM,
|
||||
"FalconForCausalLM": FalconForCausalLM,
|
||||
"GPT2LMHeadModel": GPT2LMHeadModel,
|
||||
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
|
||||
"GPTJForCausalLM": GPTJForCausalLM,
|
||||
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
|
||||
"InternLMForCausalLM": InternLMForCausalLM,
|
||||
"LlamaForCausalLM": LlamaForCausalLM,
|
||||
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
|
||||
"MPTForCausalLM": MPTForCausalLM,
|
||||
"OPTForCausalLM": OPTForCausalLM,
|
||||
"QWenLMHeadModel": QWenLMHeadModel,
|
||||
"RWForCausalLM": FalconForCausalLM,
|
||||
}
|
||||
|
||||
|
||||
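The registry above is keyed by the "architectures" field of the HuggingFace config; below is a hedged sketch of the lookup with a stand-in registry and config so it stays self-contained:

# Illustrative stand-ins; the real registry and config come from vLLM and HF.
_REGISTRY = {"LlamaForCausalLM": "vllm.model_executor.models.llama.LlamaForCausalLM"}

class FakeConfig:                      # hypothetical, mimics PretrainedConfig
    architectures = ["LlamaForCausalLM"]

def pick_architecture(config):
    # Return the first registered architecture listed in the config.
    for arch in getattr(config, "architectures", []):
        if arch in _REGISTRY:
            return _REGISTRY[arch]
    raise ValueError(f"Unsupported architectures: {config.architectures}")

print(pick_architecture(FakeConfig()))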
@contextlib.contextmanager
|
||||
def _set_default_torch_dtype(dtype: torch.dtype):
|
||||
"""Sets the default torch dtype to the given dtype."""
|
||||
old_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(dtype)
|
||||
yield
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
|
||||
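A self-contained usage sketch of the default-dtype context manager pattern above; the helper is re-declared here under a hypothetical name so the snippet runs on its own. Note that without a try/finally the previous dtype is restored only on the normal exit path, which is acceptable for this loader.

import contextlib
import torch

@contextlib.contextmanager
def set_dtype(dtype):                      # same shape as the helper above
    old = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(old)

print(torch.get_default_dtype())           # torch.float32
with set_dtype(torch.float16):
    print(torch.empty(2, 2).dtype)         # torch.float16 inside the block
print(torch.get_default_dtype())           # restored to torch.float32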
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
@ -26,26 +47,23 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
return _MODEL_REGISTRY[arch]
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}"
|
||||
)
|
||||
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
|
||||
|
||||
|
||||
def get_model(model_config: ModelConfig) -> nn.Module:
|
||||
model_class = _get_model_architecture(model_config.hf_config)
|
||||
torch.set_default_dtype(model_config.dtype)
|
||||
|
||||
with _set_default_torch_dtype(model_config.dtype):
|
||||
# Create a model instance.
|
||||
# The weights will be initialized as empty tensors.
|
||||
model = model_class(model_config.hf_config)
|
||||
if model_config.use_dummy_weights:
|
||||
if model_config.load_format == "dummy":
|
||||
model = model.cuda()
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
# random values to the weights.
|
||||
initialize_dummy_weights(model)
|
||||
else:
|
||||
# Load the weights from the cached or downloaded files.
|
||||
model.load_weights(
|
||||
model_config.model, model_config.download_dir,
|
||||
model_config.use_np_weights)
|
||||
model.load_weights(model_config.model, model_config.download_dir,
|
||||
model_config.load_format)
|
||||
model = model.cuda()
|
||||
return model.eval()
|
||||
|
||||
@ -1,12 +1,31 @@
|
||||
from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
|
||||
from vllm.model_executor.models.aquila import AquilaForCausalLM
|
||||
from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM,
|
||||
BaichuanForCausalLM)
|
||||
from vllm.model_executor.models.bloom import BloomForCausalLM
|
||||
from vllm.model_executor.models.falcon import FalconForCausalLM
|
||||
from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
|
||||
from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
|
||||
from vllm.model_executor.models.gpt_j import GPTJForCausalLM
|
||||
from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
|
||||
from vllm.model_executor.models.internlm import InternLMForCausalLM
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.model_executor.models.mpt import MPTForCausalLM
|
||||
from vllm.model_executor.models.opt import OPTForCausalLM
|
||||
|
||||
from vllm.model_executor.models.qwen import QWenLMHeadModel
|
||||
|
||||
__all__ = [
|
||||
"AquilaForCausalLM",
|
||||
"BaiChuanForCausalLM",
|
||||
"BaichuanForCausalLM",
|
||||
"BloomForCausalLM",
|
||||
"FalconForCausalLM",
|
||||
"GPT2LMHeadModel",
|
||||
"GPTBigCodeForCausalLM",
|
||||
"GPTJForCausalLM",
|
||||
"GPTNeoXForCausalLM",
|
||||
"InternLMForCausalLM",
|
||||
"LlamaForCausalLM",
|
||||
"MPTForCausalLM",
|
||||
"OPTForCausalLM",
|
||||
"QWenLMHeadModel",
|
||||
]
|
||||
|
||||
357
vllm/model_executor/models/aquila.py
Normal file
@ -0,0 +1,357 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights.
|
||||
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.aquila import AquilaConfig
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
class AquilaMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = ColumnParallelLinear(hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class AquilaRMSNorm(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
AquilaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1,
|
||||
keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance +
|
||||
self.variance_epsilon)
|
||||
|
||||
return (self.weight * hidden_states).to(input_dtype)
|
||||
|
||||
|
||||
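The normalization above is plain RMSNorm computed in float32; a short reference with made-up shapes, equivalent up to the learned weight:

import torch

x = torch.randn(2, 3, dtype=torch.float16)   # illustrative activations
eps = 1e-6
# Normalize by the root-mean-square of each row, computed in float32.
ref = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
print(ref.to(x.dtype).shape)  # torch.Size([2, 3]), cast back to the input dtype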
class AquilaAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
self.num_kv_heads = self.total_num_kv_heads // tp_size
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||
self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
rotary_dim=self.head_dim,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||
input_metadata, cache_event)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class AquilaDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: AquilaConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = AquilaAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_attention_heads,
|
||||
)
|
||||
self.mlp = AquilaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
)
|
||||
self.input_layernorm = AquilaRMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = AquilaRMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AquilaModel(nn.Module):
|
||||
|
||||
def __init__(self, config: AquilaConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
#vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
perform_initialization=False)
|
||||
self.layers = nn.ModuleList([
|
||||
AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
for i in range(len(self.layers)):
|
||||
if cache_events is None:
|
||||
cache_event = None
|
||||
else:
|
||||
cache_event = cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AquilaForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = AquilaModel(config)
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
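The lm_head above sizes its output to the vocabulary rounded up to a multiple of 64, which is commonly done so the projection splits evenly across tensor-parallel ranks; the rounding is simple integer arithmetic:

def pad_vocab(vocab_size: int, multiple: int = 64) -> int:
    # Round up to the next multiple (illustrative helper).
    return (vocab_size + multiple - 1) // multiple * multiple

print(pad_vocab(64007))  # 64064
print(pad_vocab(32000))  # 32000, already a multiple of 64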
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
|
||||
]
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||
kv_proj_shard_size = (self.config.hidden_size //
|
||||
self.config.num_attention_heads *
|
||||
self.config.num_attention_heads // tp_size)
|
||||
attention_weight_specs = [
|
||||
# (weight_name, shard_size, offset)
|
||||
("q_proj", q_proj_shard_size, 0),
|
||||
("k_proj", kv_proj_shard_size, q_proj_shard_size),
|
||||
("v_proj", kv_proj_shard_size,
|
||||
q_proj_shard_size + kv_proj_shard_size),
|
||||
]
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
is_attention_weight = False
|
||||
for weight_name, shard_size, offset in attention_weight_specs:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "qkv_proj")]
|
||||
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[offset:offset + shard_size]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_attention_weight = True
|
||||
break
|
||||
if is_attention_weight:
|
||||
continue
|
||||
|
||||
is_gate_up_weight = False
|
||||
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
break
|
||||
if is_gate_up_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tensor_model_parallel_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tensor_model_parallel_rank)
|
||||
374
vllm/model_executor/models/baichuan.py
Normal file
@ -0,0 +1,374 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only BaiChuan model compatible with HuggingFace weights.
|
||||
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
|
||||
PagedAttentionWithALiBi)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
||||
base = torch.tensor(
|
||||
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != total_num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
num_remaining_heads = min(closest_power_of_2,
|
||||
total_num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(start=1,
|
||||
end=1 + 2 * num_remaining_heads,
|
||||
step=2,
|
||||
dtype=torch.int32)
|
||||
slopes = torch.cat(
|
||||
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
return slopes
|
||||
|
||||
|
||||
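For a head count that is already a power of two, the slope construction above reduces to a geometric sequence; a quick check with 8 heads (illustrative):

import math
import torch

total_num_heads = 8                                   # a power of two
base = torch.tensor(2**(-(2**-(math.log2(total_num_heads) - 3))))
powers = torch.arange(1, 1 + total_num_heads)
print(torch.pow(base, powers))
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])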
class BaiChuanMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = ColumnParallelLinear(hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class BaiChuanAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
position_embedding: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
|
||||
)
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = (self.total_num_heads //
|
||||
tensor_model_parallel_world_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.position_embedding = position_embedding
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
self.W_pack = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
# Create the alibi slopes and slice them.
|
||||
if self.position_embedding == "ALIBI":
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
head_start = tp_rank * self.num_heads
|
||||
head_end = (tp_rank + 1) * self.num_heads
|
||||
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
|
||||
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
|
||||
|
||||
scaling = self.head_dim**-0.5
|
||||
self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
|
||||
scaling, alibi_slopes)
|
||||
else:
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
rotary_dim=self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.W_pack(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
k_cache, v_cache = kv_cache
|
||||
if self.position_embedding == "ALIBI":
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
else:
|
||||
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||
input_metadata, cache_event)
|
||||
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class BaiChuanDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: BaiChuanConfig, position_embedding: str):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = BaiChuanAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
position_embedding=position_embedding,
|
||||
)
|
||||
self.mlp = BaiChuanMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BaiChuanModel(nn.Module):
|
||||
|
||||
def __init__(self, config: BaiChuanConfig, position_embedding: str):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
perform_initialization=False)
|
||||
self.layers = nn.ModuleList([
|
||||
BaiChuanDecoderLayer(config, position_embedding)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
for i in range(len(self.layers)):
|
||||
if cache_events is None:
|
||||
cache_event = None
|
||||
else:
|
||||
cache_event = cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BaiChuanBaseForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config, position_embedding: str):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = BaiChuanModel(config, position_embedding)
|
||||
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = []
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
|
||||
if "W_pack" in name:
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // total_num_heads
|
||||
num_heads = total_num_heads // tp_world_size
|
||||
head_start = tp_rank * num_heads
|
||||
head_end = (tp_rank + 1) * num_heads
|
||||
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size, hidden_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
|
||||
is_gate_up_weight = False
|
||||
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
|
||||
(tp_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
break
|
||||
if is_gate_up_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tp_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(
|
||||
param,
|
||||
loaded_weight,
|
||||
name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tp_rank,
|
||||
)
|
||||
|
||||
|
||||
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config, "ALIBI")
|
||||
|
||||
|
||||
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config, "ROPE")
|
||||
324
vllm/model_executor/models/bloom.py
Normal file
@@ -0,0 +1,324 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
|
||||
# Copyright 2023 The CacheFlow team.
|
||||
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only BLOOM model compatible with HuggingFace weights.
|
||||
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import BloomConfig
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
||||
base = torch.tensor(
|
||||
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != total_num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
num_remaining_heads = min(closest_power_of_2,
|
||||
total_num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(start=1,
|
||||
end=1 + 2 * num_remaining_heads,
|
||||
step=2,
|
||||
dtype=torch.int32)
|
||||
slopes = torch.cat(
|
||||
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
return slopes
|
||||
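A quick worked example of the slope schedule above (an illustrative sketch, not part of this diff): for a power-of-two head count the extra-head branch never triggers and the slopes form a single geometric series. The helper name below is invented for the example.
import math

import torch


def alibi_slopes_power_of_two(total_num_heads: int) -> torch.Tensor:
    # Same closed form as `_get_alibi_slopes` above, valid only when
    # total_num_heads is already a power of two.
    base = 2.0**(-(2.0**-(math.log2(total_num_heads) - 3)))
    powers = torch.arange(1, 1 + total_num_heads, dtype=torch.float32)
    return torch.pow(torch.tensor(base), powers)


# With 8 heads the base is 0.5, so each head's slope halves:
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
print(alibi_slopes_power_of_two(8))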
|
||||
|
||||
class BloomAttention(nn.Module):
|
||||
|
||||
def __init__(self, config: BloomConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.total_num_heads = config.n_head
|
||||
self.head_dim = self.hidden_size // self.total_num_heads
|
||||
assert self.head_dim * self.total_num_heads == self.hidden_size
|
||||
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
assert self.total_num_heads % tp_world_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_world_size
|
||||
|
||||
self.query_key_value = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
3 * self.hidden_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
|
||||
# Create the alibi slopes and slice them.
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
head_start = tp_rank * self.num_heads
|
||||
head_end = (tp_rank + 1) * self.num_heads
|
||||
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
|
||||
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
|
||||
|
||||
scaling = self.head_dim**-0.5
|
||||
self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
|
||||
scaling, alibi_slopes)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
del position_ids # Unused.
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
output, _ = self.dense(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class BloomMLP(nn.Module):
|
||||
|
||||
def __init__(self, config: BloomConfig):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
|
||||
4 * hidden_size,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.act = get_act_fn("gelu")
|
||||
self.dense_4h_to_h = RowParallelLinear(4 * hidden_size,
|
||||
hidden_size,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x, _ = self.dense_h_to_4h(x)
|
||||
x = self.act(x)
|
||||
x, _ = self.dense_4h_to_h(x)
|
||||
return x
|
||||
|
||||
|
||||
class BloomBlock(nn.Module):
|
||||
|
||||
def __init__(self, config: BloomConfig):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
|
||||
self.input_layernorm = nn.LayerNorm(hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
self.self_attention = BloomAttention(config)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = BloomMLP(config)
|
||||
self.apply_residual_connection_post_layernorm = (
|
||||
config.apply_residual_connection_post_layernorm)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
# Layer norm at the beginning of the transformer layer.
|
||||
layernorm_output = self.input_layernorm(hidden_states)
|
||||
|
||||
# Layer norm post the self attention.
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = hidden_states
|
||||
|
||||
# Self attention.
|
||||
attention_output = self.self_attention(
|
||||
position_ids=position_ids,
|
||||
hidden_states=layernorm_output,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
attention_output = attention_output + residual
|
||||
layernorm_output = self.post_attention_layernorm(attention_output)
|
||||
|
||||
# Get residual
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = attention_output
|
||||
|
||||
# MLP.
|
||||
output = self.mlp(layernorm_output) + residual
|
||||
return output
|
||||
|
||||
|
||||
class BloomModel(nn.Module):
|
||||
|
||||
def __init__(self, config: BloomConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
# Embedding + LN Embedding
|
||||
self.word_embeddings = VocabParallelEmbedding(
|
||||
config.vocab_size, self.embed_dim, perform_initialization=False)
|
||||
self.word_embeddings_layernorm = nn.LayerNorm(
|
||||
self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
# Transformer blocks
|
||||
self.h = nn.ModuleList(
|
||||
[BloomBlock(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
# Final Layer Norm
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.word_embeddings(input_ids)
|
||||
hidden_states = self.word_embeddings_layernorm(hidden_states)
|
||||
for i in range(len(self.h)):
|
||||
if cache_events is None:
|
||||
cache_event = None
|
||||
else:
|
||||
cache_event = cache_events[i]
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
position_ids,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BloomForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config: BloomConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = BloomModel(config)
|
||||
# TODO(zhuohan): create a new weight after implementing pipeline
|
||||
# parallelism
|
||||
self.lm_head_weight = self.transformer.word_embeddings.weight
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"word_embeddings.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"
|
||||
]
|
||||
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if name == "lm_head.weight":
|
||||
# Since hidden_states are parallelized, we need to
|
||||
# load lm_head.weight in parallel.
|
||||
self._column_parallel_weights.append(name)
|
||||
# If lm_head is provided, use it instead.
|
||||
param = self.lm_head_weight
|
||||
else:
|
||||
if not name.startswith("transformer."):
|
||||
name = "transformer." + name
|
||||
param = state_dict[name]
|
||||
|
||||
if "query_key_value" in name:
|
||||
# NOTE(woosuk): BLOOM's fused QKV has the shape of
|
||||
# [num_heads * 3 * head_size, hidden_size], while the
|
||||
# required shape is [3 * num_heads * head_size, hidden_size].
|
||||
# Thus, we need weight conversion.
|
||||
shard_size = param.shape[0]
|
||||
start = shard_size * tp_rank
|
||||
end = shard_size * (tp_rank + 1)
|
||||
loaded_weight = loaded_weight[start:end]
|
||||
|
||||
num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // num_heads
|
||||
if "query_key_value.weight" in name:
|
||||
loaded_weight = loaded_weight.view(-1, 3, head_size,
|
||||
hidden_size)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
elif "query_key_value.bias" in name:
|
||||
loaded_weight = loaded_weight.view(-1, 3, head_size)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
else:
|
||||
raise ValueError(f"Unexpected weight name: {name}")
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights, tp_rank)
|
||||
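The fused `query_key_value` reorder in the BLOOM loader above (head-major layout to qkv-major layout) can be sanity-checked on a toy tensor. The sketch below is illustrative only; the sizes are made up, and only the view/transpose/reshape pattern mirrors the loader.
import torch

num_heads, head_size, hidden_size = 2, 3, 4
# Checkpoint layout: rows grouped as [h0_q, h0_k, h0_v, h1_q, h1_k, h1_v].
w = torch.arange(num_heads * 3 * head_size * hidden_size, dtype=torch.float32)
w = w.view(num_heads * 3 * head_size, hidden_size)

# Required layout: rows grouped as [h0_q, h1_q, h0_k, h1_k, h0_v, h1_v].
reordered = (w.view(num_heads, 3, head_size, hidden_size).transpose(
    0, 1).reshape(-1, hidden_size))

# The first head's query rows stay in place...
assert torch.equal(reordered[:head_size], w[:head_size])
# ...while the rows right after them are now h1's queries, not h0's keys.
assert torch.equal(reordered[head_size:2 * head_size],
                   w[3 * head_size:4 * head_size])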
498
vllm/model_executor/models/falcon.py
Normal file
@@ -0,0 +1,498 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights
|
||||
# reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Falcon model."""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
from transformers import FalconConfig as HF_FalconConfig
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.attention import (PagedAttention,
|
||||
PagedAttentionWithALiBi,
|
||||
PagedAttentionWithRoPE)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
|
||||
hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
|
||||
reduce_from_tensor_model_parallel_region)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs import RWConfig
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
FalconConfig = Union[HF_FalconConfig, RWConfig]
|
||||
|
||||
|
||||
# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during
|
||||
# training, this means that there's one additional quantization to bfloat16
|
||||
# between the operations. In order not to degrade the quality of our HF-port,
|
||||
# we keep these characteristics in the final model.
|
||||
class FalconLinear(nn.Linear):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = x @ self.weight.T
|
||||
if self.bias is None:
|
||||
return hidden_states
|
||||
return hidden_states + self.bias
|
||||
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
||||
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
||||
dtype=torch.float32)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != total_num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
||||
dtype=torch.float32)
|
||||
num_remaining_heads = min(closest_power_of_2,
|
||||
total_num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(1,
|
||||
1 + 2 * num_remaining_heads,
|
||||
2,
|
||||
dtype=torch.int32)
|
||||
slopes = torch.cat(
|
||||
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
|
||||
return slopes
|
||||
|
||||
|
||||
class FalconAttention(nn.Module):
|
||||
|
||||
def __init__(self, config: FalconConfig):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.head_dim = self.hidden_size // self.total_num_heads
|
||||
assert self.head_dim * self.total_num_heads == self.hidden_size
|
||||
|
||||
self.new_decoder_architecture = config.new_decoder_architecture
|
||||
self.multi_query = config.multi_query
|
||||
|
||||
if self.new_decoder_architecture:
|
||||
self.total_num_kv_heads = config.num_kv_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_kv_heads = self.total_num_kv_heads // tp_size
|
||||
self.query_key_value = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||
self.head_dim,
|
||||
bias=config.bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True,
|
||||
)
|
||||
elif self.multi_query:
|
||||
self.total_num_kv_heads = 1
|
||||
self.num_kv_heads = 1
|
||||
self.query = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
self.total_num_heads * self.head_dim,
|
||||
bias=config.bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True,
|
||||
)
|
||||
self.key_value = FalconLinear(self.hidden_size,
|
||||
2 * self.head_dim,
|
||||
bias=config.bias)
|
||||
else:
|
||||
self.total_num_kv_heads = self.total_num_heads
|
||||
self.num_kv_heads = self.num_heads
|
||||
self.query_key_value = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||
self.head_dim,
|
||||
bias=config.bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True,
|
||||
)
|
||||
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
|
||||
# Layer-wise attention scaling
|
||||
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
||||
self.reduce_row_parallel_results = not (config.new_decoder_architecture
|
||||
or config.parallel_attn)
|
||||
self.dense = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=config.bias,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True,
|
||||
reduce_results=self.reduce_row_parallel_results)
|
||||
|
||||
self.use_rotary = config.rotary
|
||||
self.use_alibi = config.alibi
|
||||
assert not (self.use_rotary and self.use_alibi), (
|
||||
"Rotary and alibi are mutually exclusive.")
|
||||
|
||||
if self.use_rotary:
|
||||
# TODO(zhuohan): Pass in the correct `max_position`
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_dim,
|
||||
self.inv_norm_factor,
|
||||
rotary_dim=self.head_dim,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
elif self.use_alibi:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
head_start = tp_rank * self.num_heads
|
||||
head_end = (tp_rank + 1) * self.num_heads
|
||||
alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
|
||||
self.inv_norm_factor)
|
||||
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
|
||||
self.attn = PagedAttentionWithALiBi(self.num_heads,
|
||||
self.head_dim,
|
||||
self.inv_norm_factor,
|
||||
alibi_slopes,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
else:
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
scale=self.inv_norm_factor,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if not self.new_decoder_architecture and self.multi_query:
|
||||
q, bias = self.query(hidden_states)
|
||||
if bias is not None:
|
||||
q += bias
|
||||
kv = self.key_value(hidden_states)
|
||||
k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
|
||||
else:
|
||||
qkv, bias = self.query_key_value(hidden_states)
|
||||
if bias is not None:
|
||||
qkv += bias
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
|
||||
dim=-1)
|
||||
k_cache, v_cache = kv_cache
|
||||
if self.use_rotary:
|
||||
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||
input_metadata, cache_event)
|
||||
else:
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
attn_output, bias = self.dense(attn_output)
|
||||
return attn_output, bias
|
||||
|
||||
|
||||
class FalconMLP(nn.Module):
|
||||
|
||||
def __init__(self, config: FalconConfig):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
|
||||
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
|
||||
4 * hidden_size,
|
||||
bias=config.bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True)
|
||||
self.act = nn.GELU()
|
||||
self.reduce_row_parallel_results = not (config.new_decoder_architecture
|
||||
or config.parallel_attn)
|
||||
self.dense_4h_to_h = RowParallelLinear(
|
||||
4 * hidden_size,
|
||||
hidden_size,
|
||||
bias=config.bias,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True,
|
||||
reduce_results=self.reduce_row_parallel_results)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
|
||||
x, bias = self.dense_h_to_4h(x)
|
||||
if bias is not None:
|
||||
x += bias
|
||||
x = self.act(x)
|
||||
x, bias = self.dense_4h_to_h(x)
|
||||
return x, bias
|
||||
|
||||
|
||||
class FalconDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: FalconConfig):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.self_attention = FalconAttention(config)
|
||||
self.mlp = FalconMLP(config)
|
||||
self.config = config
|
||||
|
||||
if config.new_decoder_architecture:
|
||||
# The layer norm before self-attention
|
||||
self.ln_attn = LayerNorm(hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
# The layer norm before the MLP
|
||||
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
else:
|
||||
self.input_layernorm = LayerNorm(hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
if not config.parallel_attn:
|
||||
self.post_attention_layernorm = LayerNorm(
|
||||
hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.reduce_row_parallel_results = not (config.new_decoder_architecture
|
||||
or config.parallel_attn)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
if self.config.new_decoder_architecture:
|
||||
attention_layernorm_out = self.ln_attn(hidden_states)
|
||||
mlp_layernorm_out = self.ln_mlp(hidden_states)
|
||||
else:
|
||||
attention_layernorm_out = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self attention.
|
||||
attention_output, attention_bias = self.self_attention(
|
||||
positions=positions,
|
||||
hidden_states=attention_layernorm_out,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
if self.reduce_row_parallel_results and attention_bias is not None:
|
||||
attention_output += attention_bias
|
||||
|
||||
if not self.config.new_decoder_architecture:
|
||||
if self.config.parallel_attn:
|
||||
mlp_layernorm_out = attention_layernorm_out
|
||||
else:
|
||||
residual += attention_output
|
||||
mlp_layernorm_out = self.post_attention_layernorm(residual)
|
||||
|
||||
# MLP.
|
||||
mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
|
||||
if self.reduce_row_parallel_results and mlp_bias is not None:
|
||||
mlp_output += mlp_bias
|
||||
|
||||
if not self.reduce_row_parallel_results:
|
||||
# When MLP and Attention layers are parallel, we can use
|
||||
# only one all-reduce operator to reduce the results from
|
||||
# both MLP and Attention layers.
|
||||
mlp_output += attention_output
|
||||
mlp_output = reduce_from_tensor_model_parallel_region(mlp_output)
|
||||
if attention_bias is not None:
|
||||
mlp_output += attention_bias
|
||||
if mlp_bias is not None:
|
||||
mlp_output += mlp_bias
|
||||
|
||||
output = mlp_output + residual
|
||||
|
||||
return output
|
||||
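The single-all-reduce comment above works because all-reduce with a sum is linear: reducing the locally summed attention and MLP partials once equals reducing each of them separately. A minimal simulated check (plain tensors standing in for per-rank partials, no torch.distributed; purely illustrative):
import torch

world_size = 4
attn_partials = [torch.randn(2, 8) for _ in range(world_size)]
mlp_partials = [torch.randn(2, 8) for _ in range(world_size)]

# Two all-reduces: reduce attention and MLP outputs independently.
two_reduces = sum(attn_partials) + sum(mlp_partials)
# One all-reduce: sum the partials locally first, then reduce once.
one_reduce = sum(a + m for a, m in zip(attn_partials, mlp_partials))

assert torch.allclose(two_reduces, one_reduce)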
|
||||
|
||||
class FalconModel(nn.Module):
|
||||
|
||||
def __init__(self, config: FalconConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.use_alibi = config.alibi
|
||||
|
||||
# Embedding + LN Embedding
|
||||
self.word_embeddings = VocabParallelEmbedding(
|
||||
config.vocab_size, self.embed_dim, perform_initialization=False)
|
||||
|
||||
# Transformer blocks
|
||||
self.h = nn.ModuleList([
|
||||
FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
# Final Layer Norm
|
||||
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.word_embeddings(input_ids)
|
||||
for i in range(len(self.h)):
|
||||
if cache_events is None:
|
||||
cache_event = None
|
||||
else:
|
||||
cache_event = cache_events[i]
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FalconForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config: FalconConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = FalconModel(config)
|
||||
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.transformer(
|
||||
input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
input_metadata,
|
||||
cache_events,
|
||||
)
|
||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||
input_metadata)
|
||||
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"word_embeddings.weight", "lm_head.weight", "dense_h_to_4h.weight",
|
||||
"dense_h_to_4h.bias"
|
||||
]
|
||||
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
tp_size = (get_tensor_model_parallel_world_size())
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
hidden_size = self.config.hidden_size
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
num_heads = total_num_heads // tp_size
|
||||
head_size = hidden_size // total_num_heads
|
||||
head_start = tp_rank * num_heads
|
||||
head_end = (tp_rank + 1) * num_heads
|
||||
if self.config.new_decoder_architecture:
|
||||
total_num_kv_heads = self.config.num_kv_heads
|
||||
num_kv_heads = total_num_kv_heads // tp_size
|
||||
separated_q_kv = False
|
||||
kv_head_start = tp_rank * num_kv_heads
|
||||
kv_head_end = (tp_rank + 1) * num_kv_heads
|
||||
elif self.config.multi_query:
|
||||
total_num_kv_heads = 1
|
||||
num_kv_heads = 1
|
||||
separated_q_kv = True
|
||||
kv_head_start = 0
|
||||
kv_head_end = 1
|
||||
else:
|
||||
total_num_kv_heads = total_num_heads
|
||||
num_kv_heads = total_num_kv_heads // tp_size
|
||||
separated_q_kv = False
|
||||
kv_head_start = tp_rank * num_kv_heads
|
||||
kv_head_end = (tp_rank + 1) * num_kv_heads
|
||||
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "query_key_value" in name:
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
loaded_weight_size = loaded_weight.size()
|
||||
loaded_weight = loaded_weight.view(
|
||||
total_num_kv_heads, num_query_heads_per_kv_head + 2,
|
||||
head_size, *loaded_weight_size[1:])
|
||||
|
||||
wq = loaded_weight[:, :-2].reshape(-1, *loaded_weight_size[1:])
|
||||
wk = loaded_weight[:, [-2]].reshape(-1,
|
||||
*loaded_weight_size[1:])
|
||||
wv = loaded_weight[:, [-1]].reshape(-1,
|
||||
*loaded_weight_size[1:])
|
||||
|
||||
wq = wq[head_size * head_start:head_size * head_end]
|
||||
wk = wk[head_size * kv_head_start:head_size * kv_head_end]
|
||||
wv = wv[head_size * kv_head_start:head_size * kv_head_end]
|
||||
|
||||
if separated_q_kv:
|
||||
loaded_weight_q = wq
|
||||
loaded_weight_kv = torch.cat([wk, wv], dim=0)
|
||||
q_weight_name = name.replace("query_key_value", "query")
|
||||
kv_weight_name = name.replace("query_key_value",
|
||||
"key_value")
|
||||
load_tensor_parallel_weights(state_dict[q_weight_name],
|
||||
loaded_weight_q,
|
||||
q_weight_name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tp_rank)
|
||||
load_tensor_parallel_weights(state_dict[kv_weight_name],
|
||||
loaded_weight_kv,
|
||||
kv_weight_name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tp_rank)
|
||||
continue
|
||||
else:
|
||||
loaded_weight = torch.cat([wq, wk, wv], dim=0)
|
||||
|
||||
param = state_dict[name]
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights, tp_rank)
|
||||
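The Falcon loader above splits the fused `query_key_value` weight by viewing it as KV-head groups of (query heads, key head, value head). A toy-sized sketch of that split, illustrative only and with made-up dimensions:
import torch

total_num_kv_heads = 2
num_query_heads_per_kv_head = 3
head_size, hidden_size = 2, 5

rows = total_num_kv_heads * (num_query_heads_per_kv_head + 2) * head_size
fused = torch.arange(rows * hidden_size,
                     dtype=torch.float32).view(rows, hidden_size)

grouped = fused.view(total_num_kv_heads, num_query_heads_per_kv_head + 2,
                     head_size, hidden_size)
wq = grouped[:, :-2].reshape(-1, hidden_size)   # all query heads
wk = grouped[:, [-2]].reshape(-1, hidden_size)  # one key head per KV group
wv = grouped[:, [-1]].reshape(-1, hidden_size)  # one value head per KV group

assert wq.shape[0] == total_num_kv_heads * num_query_heads_per_kv_head * head_size
assert wk.shape[0] == wv.shape[0] == total_num_kv_heads * head_size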
vllm/model_executor/models/gpt2.py
@@ -1,5 +1,6 @@
|
||||
# coding=utf-8
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
@@ -20,7 +21,7 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -30,13 +31,14 @@ from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import PagedAttention
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
@@ -47,19 +49,25 @@ class GPT2Attention(nn.Module):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
total_num_heads = config.num_attention_heads
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
assert total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = total_num_heads // tensor_model_parallel_world_size
|
||||
self.head_dim = self.hidden_size // total_num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
|
||||
bias=True, gather_output=False,
|
||||
self.c_attn = ColumnParallelLinear(self.hidden_size,
|
||||
3 * self.hidden_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
|
||||
bias=True, input_is_parallel=True,
|
||||
self.c_proj = RowParallelLinear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.attn = PagedAttention(self.num_heads, self.head_dim,
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
scale=self.scale)
|
||||
|
||||
def forward(
|
||||
@@ -72,8 +80,8 @@ class GPT2Attention(nn.Module):
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(
|
||||
q, k, v, key_cache, value_cache, input_metadata, cache_event)
|
||||
attn_output = self.attn(q, k, v, key_cache, value_cache,
|
||||
input_metadata, cache_event)
|
||||
attn_output, _ = self.c_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
@@ -87,11 +95,15 @@ class GPT2MLP(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.c_fc = ColumnParallelLinear(hidden_size, intermediate_size,
|
||||
bias=True, gather_output=False,
|
||||
self.c_fc = ColumnParallelLinear(hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
|
||||
bias=True, input_is_parallel=True,
|
||||
self.c_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.act = get_act_fn(config.activation_function)
|
||||
|
||||
@@ -107,7 +119,8 @@ class GPT2Block(nn.Module):
|
||||
def __init__(self, config: GPT2Config):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
||||
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
|
||||
hidden_size)
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPT2Attention(config)
|
||||
@@ -145,9 +158,9 @@ class GPT2Model(nn.Module):
|
||||
def __init__(self, config: GPT2Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
assert config.add_cross_attention == False
|
||||
assert config.scale_attn_by_inverse_layer_idx == False
|
||||
assert config.reorder_and_upcast_attn == False
|
||||
assert not config.add_cross_attention
|
||||
assert not config.scale_attn_by_inverse_layer_idx
|
||||
assert not config.reorder_and_upcast_attn
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
# Optimization: While the vocab size of GPT-2 is 50257, we extend it
|
||||
@@ -180,8 +193,8 @@ class GPT2Model(nn.Module):
|
||||
else:
|
||||
cache_event = cache_events[i]
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
hidden_states, kv_caches[i], input_metadata, cache_event)
|
||||
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
|
||||
cache_event)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
@@ -205,35 +218,41 @@ class GPT2LMHeadModel(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
hidden_states = self.transformer(
|
||||
input_ids, positions, kv_caches, input_metadata, cache_events)
|
||||
next_tokens = self.sampler(
|
||||
self.lm_head_weight, hidden_states, input_metadata)
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
|
||||
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
|
||||
_row_parallel_weights = ["c_proj.weight"]
|
||||
|
||||
def load_weights(self, model_name_or_path: str,
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||
load_format: str = "auto"):
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "lm_head.weight" in name:
|
||||
# GPT-2 ties the weights of the embedding layer and the final
|
||||
# linear layer.
|
||||
continue
|
||||
if ".attn.bias" in name:
|
||||
if ".attn.bias" in name or ".attn.masked_bias" in name:
|
||||
# Skip attention mask.
|
||||
# NOTE: "c_attn.bias" should not be skipped.
|
||||
continue
|
||||
|
||||
if not name.startswith("transformer."):
|
||||
name = "transformer." + name
|
||||
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
|
||||
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
|
||||
# Because of this, we need to transpose the weights.
|
||||
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
|
||||
@@ -245,17 +264,16 @@ class GPT2LMHeadModel(nn.Module):
|
||||
param = state_dict[name]
|
||||
|
||||
if name == "transformer.wte.weight":
|
||||
# Consider padding in the vocab size.
|
||||
padded_vocab_size = param.shape[0] * tensor_model_parallel_world_size
|
||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
||||
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
|
||||
extra_rows = extra_rows.to(loaded_weight)
|
||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tensor_model_parallel_rank)
|
||||
continue
|
||||
|
||||
# For the fused QKV linear layer, manually shard the weights.
|
||||
if "c_attn" in name:
|
||||
# GPT-2's fused QKV has the shape of [3 * num_heads * head_size, hidden_size].
|
||||
# When tensor parallelism is used, we shard the weights along the head dimension.
|
||||
# GPT-2's fused QKV has the shape of
|
||||
# [3 * num_heads * head_size, hidden_size].
|
||||
# When tensor parallelism is used, we shard the weights along
|
||||
# the head dimension.
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // total_num_heads
|
||||
@@ -264,11 +282,13 @@ class GPT2LMHeadModel(nn.Module):
|
||||
head_end = (tensor_model_parallel_rank + 1) * num_heads
|
||||
|
||||
if name.endswith(".weight"):
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads, head_size, hidden_size)
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size, hidden_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
elif name.endswith(".bias"):
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads, head_size)
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :]
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
else:
|
||||
|
||||
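The head-dimension sharding of GPT-2's fused `c_attn` weight described in the comments above, pictured for a toy 2-way tensor-parallel split (illustrative sketch with made-up sizes):
import torch

total_num_heads, head_size, hidden_size = 4, 2, 6
tp_world_size, tp_rank = 2, 1
num_heads = total_num_heads // tp_world_size
head_start, head_end = tp_rank * num_heads, (tp_rank + 1) * num_heads

full = torch.randn(3 * total_num_heads * head_size, hidden_size)
grouped = full.view(3, total_num_heads, head_size, hidden_size)
shard = grouped[:, head_start:head_end, :, :].reshape(-1, hidden_size)

# Each rank keeps the q, k and v rows of its own heads only.
assert shard.shape == (3 * num_heads * head_size, hidden_size)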
340
vllm/model_executor/models/gpt_bigcode.py
Normal file
@@ -0,0 +1,340 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 CTranslate2, and Michael Feil
|
||||
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GPTBigCode model compatible with HuggingFace weights.
|
||||
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import GPTBigCodeConfig
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import PagedAttention
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
class GPTBigCodeAttention(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTBigCodeConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
total_num_heads = config.num_attention_heads
|
||||
self.tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
assert total_num_heads % self.tensor_model_parallel_world_size == 0
|
||||
self.num_heads = (total_num_heads //
|
||||
self.tensor_model_parallel_world_size)
|
||||
self.head_dim = self.hidden_size // total_num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.multi_query = config.multi_query
|
||||
if self.multi_query:
|
||||
self.num_kv_heads = 1
|
||||
self.kv_dim = self.head_dim
|
||||
self.c_attn_q = ColumnParallelLinear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.c_attn_kv = nn.Linear(self.hidden_size,
|
||||
2 * self.kv_dim,
|
||||
bias=True)
|
||||
else:
|
||||
self.num_kv_heads = self.num_heads
|
||||
self.kv_dim = self.num_kv_heads * self.head_dim
|
||||
self.c_attn = ColumnParallelLinear(self.hidden_size,
|
||||
self.hidden_size +
|
||||
2 * self.kv_dim,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
|
||||
self.c_proj = RowParallelLinear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
scale=self.scale,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
if self.multi_query:
|
||||
q, _ = self.c_attn_q(hidden_states)
|
||||
kv = self.c_attn_kv(hidden_states)
|
||||
k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
|
||||
else:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.split([
|
||||
self.hidden_size // self.tensor_model_parallel_world_size,
|
||||
self.kv_dim, self.kv_dim
|
||||
],
|
||||
dim=-1)
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, key_cache, value_cache,
|
||||
input_metadata, cache_event)
|
||||
attn_output, _ = self.c_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class GPTBigMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
intermediate_size: int,
|
||||
config: GPTBigCodeConfig,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.c_fc = ColumnParallelLinear(hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.c_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.act = get_act_fn(config.activation_function)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.c_fc(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states, _ = self.c_proj(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GPTBigCodeBlock(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTBigCodeConfig):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
|
||||
hidden_size)
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPTBigCodeAttention(config)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = GPTBigMLP(inner_dim, config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
attn_output = self.attn(
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
# residual connection
|
||||
hidden_states = attn_output + residual
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_2(hidden_states)
|
||||
feed_forward_hidden_states = self.mlp(hidden_states)
|
||||
# residual connection
|
||||
hidden_states = residual + feed_forward_hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GPTBigCodeModel(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTBigCodeConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
assert not config.add_cross_attention
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
# Optimization: While the vocab size of GPT-2 is 50257, we extend it
|
||||
# to 50304 in order to make it divisible by 64.
|
||||
# This improves performance since GPUs are faster if the dimension
|
||||
# is divisible by 64. In addition, it allows us to shard the embedding
|
||||
# layer across 2, 4, 8, or more GPUs.
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
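# Worked example of the padding above (an illustrative note, not in the original
# diff): ((50257 + 63) // 64) * 64 == 50304, i.e. GPT-2's 50257-token vocab gains
# 47 unused rows so the embedding table is divisible by 64 and shards evenly
# across 2, 4, or 8 GPUs.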
self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.h = nn.ModuleList(
|
||||
[GPTBigCodeBlock(config) for _ in range(config.num_hidden_layers)])
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
for i in range(len(self.h)):
|
||||
if cache_events is None:
|
||||
cache_event = None
|
||||
else:
|
||||
cache_event = cache_events[i]
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
|
||||
cache_event)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GPTBigCodeForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTBigCodeConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = GPTBigCodeModel(config)
|
||||
# TODO(zhuohan): create a new weight after implementing pipeline
|
||||
# parallelism
|
||||
self.lm_head_weight = self.transformer.wte.weight
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> SamplerOutput:
|
||||
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
    _row_parallel_weights = ["c_proj.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto"):
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format):
            if "lm_head.weight" in name:
                # GPT-2 ties the weights of the embedding layer and the final
                # linear layer.
                continue
            if ".attn.bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue

            if not name.startswith("transformer."):
                name = "transformer." + name

            # For the fused QKV linear layer, manually shard the weights.
            if "c_attn" in name:
                # GPT-2's fused QKV has the shape of
                # [3 * num_heads * head_size, hidden_size].
                # When tensor parallelism is used, we shard the weights along
                # the head dimension.
                total_num_heads = self.config.num_attention_heads
                total_num_kv_heads = (1 if self.config.multi_query else
                                      total_num_heads)
                hidden_size = self.config.hidden_size
                head_size = hidden_size // total_num_heads
                total_kv_size = head_size * total_num_kv_heads
                num_heads = total_num_heads // tensor_model_parallel_world_size
                head_start = tensor_model_parallel_rank * num_heads
                head_end = (tensor_model_parallel_rank + 1) * num_heads

                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                wq, wk, wv = torch.split(
                    loaded_weight, [hidden_size, total_kv_size, total_kv_size],
                    dim=0)

                wq = wq[head_size * head_start:head_size * head_end]
                if not self.config.multi_query:
                    # Split the heads when using normal multi-head attention
                    wk = wk[head_size * head_start:head_size * head_end]
                    wv = wv[head_size * head_start:head_size * head_end]
                    loaded_weight = torch.cat([wq, wk, wv], dim=0)
                else:
                    # For multi-query attention, we split the query
                    # but replicate the key and value.
                    loaded_weight_q = wq
                    loaded_weight_kv = torch.cat([wk, wv], dim=0)
                    q_weight_name = name.replace("c_attn", "c_attn_q")
                    kv_weight_name = name.replace("c_attn", "c_attn_kv")
                    load_tensor_parallel_weights(state_dict[q_weight_name],
                                                 loaded_weight_q,
                                                 q_weight_name,
                                                 self._column_parallel_weights,
                                                 self._row_parallel_weights,
                                                 tensor_model_parallel_rank)
                    load_tensor_parallel_weights(state_dict[kv_weight_name],
                                                 loaded_weight_kv,
                                                 kv_weight_name,
                                                 self._column_parallel_weights,
                                                 self._row_parallel_weights,
                                                 tensor_model_parallel_rank)
                    continue

            param = state_dict[name]

            if name == "transformer.wte.weight":
                load_padded_tensor_parallel_vocab(param, loaded_weight,
                                                  tensor_model_parallel_rank)
                continue

            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
                                         tensor_model_parallel_rank)
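# Illustrative sketch (not part of the diff): how the fused c_attn weight above
# is sharded along the head dimension. All sizes below are hypothetical small
# values chosen only to make the arithmetic easy to follow.
import torch

total_num_heads, head_size, hidden_size = 4, 8, 32   # hypothetical config
multi_query = True
total_num_kv_heads = 1 if multi_query else total_num_heads
total_kv_size = head_size * total_num_kv_heads

# Fused checkpoint weight: [hidden_size + 2 * total_kv_size, hidden_size].
fused = torch.randn(hidden_size + 2 * total_kv_size, hidden_size)
wq, wk, wv = torch.split(fused, [hidden_size, total_kv_size, total_kv_size],
                         dim=0)

tp_world_size, tp_rank = 2, 0                         # hypothetical 2-way TP
num_heads = total_num_heads // tp_world_size
head_start, head_end = tp_rank * num_heads, (tp_rank + 1) * num_heads

# Each rank keeps only its slice of the query heads; with multi-query
# attention the single K/V head is replicated on every rank instead.
wq_shard = wq[head_size * head_start:head_size * head_end]
assert wq_shard.shape == (num_heads * head_size, hidden_size)
assert wk.shape == (total_kv_size, hidden_size)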
vllm/model_executor/models/gpt_j.py (new file)
@@ -0,0 +1,254 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTJConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPTJAttention(nn.Module):

    def __init__(self, config: GPTJConfig):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        self.qkv_proj = ColumnParallelLinear(config.hidden_size,
                                             3 * config.hidden_size,
                                             bias=False,
                                             gather_output=False,
                                             perform_initialization=False)
        self.out_proj = RowParallelLinear(config.hidden_size,
                                          config.hidden_size,
                                          bias=False,
                                          input_is_parallel=True,
                                          perform_initialization=False)

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        assert getattr(config, "rotary", True)
        assert config.rotary_dim % 2 == 0
        self.attn = PagedAttentionWithRoPE(self.num_heads,
                                           self.head_size,
                                           scaling,
                                           config.rotary_dim,
                                           is_neox_style=False)
        self.warmup = False

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class GPTJMLP(nn.Module):

    def __init__(self, intermediate_size: int, config: GPTJConfig):
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(hidden_size,
                                          intermediate_size,
                                          gather_output=False,
                                          perform_initialization=False)
        self.fc_out = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        input_is_parallel=True,
                                        perform_initialization=False)
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states


class GPTJBlock(nn.Module):

    def __init__(self, config: GPTJConfig):
        super().__init__()
        if config.n_inner is None:
            inner_dim = 4 * config.n_embd
        else:
            inner_dim = config.n_inner
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config)
        self.mlp = GPTJMLP(inner_dim, config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states


class GPTJModel(nn.Module):

    def __init__(self, config: GPTJConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          self.embed_dim,
                                          perform_initialization=False)
        self.h = nn.ModuleList(
            [GPTJBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.h)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.h[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPTJForCausalLM(nn.Module):

    def __init__(self, config: GPTJConfig):
        super().__init__()
        self.config = config
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(config)
        self.lm_head = ColumnParallelLinear(config.n_embd,
                                            config.vocab_size,
                                            gather_output=False,
                                            perform_initialization=False)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata, self.lm_head.bias)
        return next_tokens

    _column_parallel_weights = [
        "wte.weight", "fc_in.weight", "fc_in.bias", "lm_head.weight",
        "lm_head.bias"
    ]
    _row_parallel_weights = ["out_proj.weight", "fc_out.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto"):
        tp_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format):
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue

            is_attention_weight = False
            for stride_id, att_weight_name in enumerate(
                    ["q_proj", "k_proj", "v_proj"]):
                if att_weight_name not in name:
                    continue
                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                shard_size = param.shape[1]
                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
                                              (tp_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_attention_weight = True
                break
            if is_attention_weight:
                continue

            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights, tp_rank)
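# Illustrative sketch (not part of the diff): the stride-based copy used by the
# load_weights methods above. A column-parallel qkv_proj parameter holds this
# rank's Q, K and V rows stacked one after another, so each unfused checkpoint
# matrix is sliced by tensor-parallel rank and written into the slot selected
# by its stride_id. All sizes below are hypothetical.
import torch

hidden_size, tp_size, tp_rank = 16, 2, 0
shard_size = hidden_size // tp_size            # rows per projection per rank
qkv_param = torch.empty(3 * shard_size, hidden_size)

checkpoint = {
    "q_proj": torch.randn(hidden_size, hidden_size),
    "k_proj": torch.randn(hidden_size, hidden_size),
    "v_proj": torch.randn(hidden_size, hidden_size),
}
for stride_id, name in enumerate(["q_proj", "k_proj", "v_proj"]):
    loaded = checkpoint[name][shard_size * tp_rank:shard_size * (tp_rank + 1)]
    qkv_param[shard_size * stride_id:shard_size * (stride_id + 1)].copy_(loaded)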
vllm/model_executor/models/gpt_neox.py
@@ -1,5 +1,6 @@
|
||||
# coding=utf-8
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
@ -19,7 +20,7 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -35,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
@ -48,15 +49,19 @@ class GPTNeoXAttention(nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.total_num_heads
|
||||
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
|
||||
self.num_heads = (self.total_num_heads //
|
||||
tensor_model_parallel_world_size)
|
||||
|
||||
self.query_key_value = ColumnParallelLinear(config.hidden_size,
|
||||
self.query_key_value = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
3 * config.hidden_size,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.dense = RowParallelLinear(config.hidden_size, config.hidden_size,
|
||||
self.dense = RowParallelLinear(config.hidden_size,
|
||||
config.hidden_size,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
|
||||
@ -75,11 +80,10 @@ class GPTNeoXAttention(nn.Module):
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(
|
||||
position_ids, q, k, v, k_cache, v_cache, input_metadata, cache_event)
|
||||
attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
|
||||
input_metadata, cache_event)
|
||||
output, _ = self.dense(attn_output)
|
||||
return output
|
||||
|
||||
@ -92,7 +96,8 @@ class GPTNeoXMLP(nn.Module):
|
||||
config.intermediate_size,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size,
|
||||
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.act = get_act_fn(config.hidden_act)
|
||||
@ -109,8 +114,10 @@ class GPTNeoXLayer(nn.Module):
|
||||
def __init__(self, config: GPTNeoXConfig):
|
||||
super().__init__()
|
||||
self.use_parallel_residual = config.use_parallel_residual
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.attention = GPTNeoXAttention(config)
|
||||
self.mlp = GPTNeoXMLP(config)
|
||||
|
||||
@ -154,10 +161,13 @@ class GPTNeoXModel(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.embed_in = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
|
||||
self.embed_in = VocabParallelEmbedding(config.vocab_size,
|
||||
config.hidden_size,
|
||||
perform_initialization=False)
|
||||
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.layers = nn.ModuleList(
|
||||
[GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -191,8 +201,10 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.gpt_neox = GPTNeoXModel(config)
|
||||
self.embed_out = ColumnParallelLinear(config.hidden_size, config.vocab_size,
|
||||
bias=False, gather_output=False,
|
||||
self.embed_out = ColumnParallelLinear(config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
@ -203,23 +215,27 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
hidden_states = self.gpt_neox(
|
||||
input_ids, positions, kv_caches, input_metadata, cache_events)
|
||||
next_tokens = self.sampler(
|
||||
self.embed_out.weight, hidden_states, input_metadata)
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.embed_out.weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
|
||||
_column_parallel_weights = [
|
||||
"embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight",
|
||||
"dense_h_to_4h.bias"
|
||||
]
|
||||
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
|
||||
|
||||
def load_weights(self, model_name_or_path: str,
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
load_format: str = "auto"):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if ("attention.bias" in name or "attention.masked_bias" in name
|
||||
or "rotary_emb.inv_freq" in name):
|
||||
continue
|
||||
@ -230,17 +246,19 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
# required shape is [3 * num_heads * head_size, hidden_size].
|
||||
# Thus, we need weight conversion.
|
||||
shard_size = param.shape[0]
|
||||
loaded_weight = loaded_weight[shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
|
||||
num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // num_heads
|
||||
if 'query_key_value.weight' in name:
|
||||
loaded_weight = loaded_weight.view(-1, 3, head_size, hidden_size)
|
||||
if "query_key_value.weight" in name:
|
||||
loaded_weight = loaded_weight.view(-1, 3, head_size,
|
||||
hidden_size)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
elif 'query_key_value.bias' in name:
|
||||
elif "query_key_value.bias" in name:
|
||||
loaded_weight = loaded_weight.view(-1, 3, head_size)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
|
||||
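# Illustrative sketch (not part of the diff): the query_key_value weight
# conversion performed in GPTNeoXForCausalLM.load_weights above. The HF
# checkpoint interleaves Q/K/V per head ([num_heads, 3, head_size, hidden]),
# while the fused layer expects all Q rows, then all K rows, then all V rows.
# Sizes here are hypothetical.
import torch

num_heads, head_size, hidden_size = 4, 8, 32
w = torch.randn(num_heads * 3 * head_size, hidden_size)  # per-rank shard

w = w.view(-1, 3, head_size, hidden_size)  # [num_heads, 3, head_size, hidden]
w = w.transpose(0, 1)                      # [3, num_heads, head_size, hidden]
w = w.reshape(-1, hidden_size)             # [3 * num_heads * head_size, hidden]
assert w.shape == (3 * num_heads * head_size, hidden_size)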
vllm/model_executor/models/internlm.py (new file)
@@ -0,0 +1,292 @@
# -*- coding: utf-8 -*-
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)
from vllm.model_executor.weight_utils import (
    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
    load_tensor_parallel_weights)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class InternLMMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_up_proj = ColumnParallelLinear(hidden_size,
                                                 2 * intermediate_size,
                                                 bias=False,
                                                 gather_output=False,
                                                 perform_initialization=False)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           input_is_parallel=True,
                                           perform_initialization=False)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class InternLMAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
            3 * self.total_num_heads * self.head_dim,
            bias=True,
            gather_output=False,
            perform_initialization=False,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=True,
            input_is_parallel=True,
            perform_initialization=False,
        )
        self.attn = PagedAttentionWithRoPE(self.num_heads,
                                           self.head_dim,
                                           self.scaling,
                                           rotary_dim=self.head_dim)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        output, _ = self.o_proj(attn_output)
        return output


class InternLMDecoderLayer(nn.Module):

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = InternLMAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
        )
        self.mlp = InternLMMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class InternLMModel(nn.Module):

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.embed_tokens = VocabParallelEmbedding(
            vocab_size, config.hidden_size, perform_initialization=False)
        self.layers = nn.ModuleList([
            InternLMDecoderLayer(config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class InternLMForCausalLM(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = InternLMModel(config)
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                            vocab_size,
                                            bias=False,
                                            gather_output=False,
                                            perform_initialization=False)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = [
        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
    ]
    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto"):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format):
            if "rotary_emb.inv_freq" in name:
                continue

            if "embed_tokens" in name or "lm_head" in name:
                param = state_dict[name]
                load_padded_tensor_parallel_vocab(param, loaded_weight,
                                                  tensor_model_parallel_rank)
                continue

            is_attention_weight = False
            for stride_id, att_weight_name in enumerate(
                    ["q_proj", "k_proj", "v_proj"]):
                if att_weight_name not in name:
                    continue
                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                shard_size = param.shape[0] // 3
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_attention_weight = True
                break
            if is_attention_weight:
                continue

            is_gate_up_weight = False
            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
                if weight_name not in name:
                    continue
                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                shard_size = param.shape[0] // 2
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_gate_up_weight = True
                break
            if is_gate_up_weight:
                continue

            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
                                         tensor_model_parallel_rank)
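# Illustrative sketch (not part of the diff): the vocab padding used above.
# ((vocab_size + 63) // 64) * 64 rounds the vocabulary up to a multiple of 64
# so the padded embedding/lm_head can be split evenly across tensor-parallel
# ranks; load_padded_tensor_parallel_vocab then fills only the real rows.
# The numbers below are hypothetical.
vocab_size = 32000 + 5                    # a vocab that is not 64-aligned
padded = ((vocab_size + 63) // 64) * 64   # -> 32064
assert padded % 64 == 0
assert padded == 32064 and padded - vocab_size == 59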
vllm/model_executor/models/llama.py
@@ -1,5 +1,6 @@
|
||||
# coding=utf-8
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
@ -24,25 +25,25 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.weight_utils import (
|
||||
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
@ -56,15 +57,19 @@ class LlamaMLP(nn.Module):
|
||||
hidden_act: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
|
||||
bias=False, gather_output=False,
|
||||
self.gate_up_proj = ColumnParallelLinear(hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
|
||||
bias=False, input_is_parallel=True,
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
if hidden_act != 'silu':
|
||||
raise ValueError(f'Unsupported activation: {hidden_act}. '
|
||||
'Only silu is supported for now.')
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
@ -80,19 +85,28 @@ class LlamaAttention(nn.Module):
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
self.num_kv_heads = self.total_num_kv_heads // tp_size
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
3 * self.total_num_heads * self.head_dim,
|
||||
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||
self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
@ -104,8 +118,12 @@ class LlamaAttention(nn.Module):
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim,
|
||||
self.scaling, rotary_dim=self.head_dim)
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
base=self.rope_theta,
|
||||
rotary_dim=self.head_dim,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -116,10 +134,10 @@ class LlamaAttention(nn.Module):
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(
|
||||
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
|
||||
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||
input_metadata, cache_event)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -129,17 +147,23 @@ class LlamaDecoderLayer(nn.Module):
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
self.self_attn = LlamaAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
)
|
||||
self.mlp = LlamaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -177,9 +201,12 @@ class LlamaModel(nn.Module):
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
|
||||
perform_initialization=False)
|
||||
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
vocab_size, config.hidden_size, perform_initialization=False)
|
||||
self.layers = nn.ModuleList([
|
||||
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
@ -209,12 +236,14 @@ class LlamaModel(nn.Module):
|
||||
|
||||
|
||||
class LlamaForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = LlamaModel(config)
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||
config.vocab_size,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
@ -227,41 +256,54 @@ class LlamaForCausalLM(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, kv_caches, input_metadata, cache_events)
|
||||
next_tokens = self.sampler(
|
||||
self.lm_head.weight, hidden_states, input_metadata)
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
|
||||
"qkv_proj.weight", "gate_proj.weight",
|
||||
"up_proj.weight"]
|
||||
_column_parallel_weights = [
|
||||
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
|
||||
]
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
|
||||
def load_weights(self, model_name_or_path: str,
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
load_format: str = "auto"):
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||
kv_proj_shard_size = (self.config.hidden_size //
|
||||
self.config.num_attention_heads *
|
||||
self.config.num_key_value_heads // tp_size)
|
||||
attention_weight_specs = [
|
||||
# (weight_name, shard_size, offset)
|
||||
("q_proj", q_proj_shard_size, 0),
|
||||
("k_proj", kv_proj_shard_size, q_proj_shard_size),
|
||||
("v_proj", kv_proj_shard_size,
|
||||
q_proj_shard_size + kv_proj_shard_size),
|
||||
]
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
is_attention_weight = False
|
||||
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
|
||||
if att_weight_name not in name:
|
||||
for weight_name, shard_size, offset in attention_weight_specs:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
|
||||
shard_size = param.shape[0] // 3
|
||||
param = state_dict[name.replace(weight_name, "qkv_proj")]
|
||||
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id
|
||||
:shard_size * (stride_id + 1)]
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[offset:offset + shard_size]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_attention_weight = True
|
||||
break
|
||||
@ -275,10 +317,10 @@ class LlamaForCausalLM(nn.Module):
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id
|
||||
:shard_size * (stride_id + 1)]
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
@ -287,6 +329,12 @@ class LlamaForCausalLM(nn.Module):
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tensor_model_parallel_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
|
||||
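# Illustrative sketch (not part of the diff): the attention_weight_specs
# arithmetic introduced above for grouped-query attention, with hypothetical
# Llama-2-70B-like sizes and 8-way tensor parallelism.
hidden_size, num_heads, num_kv_heads, tp_size = 8192, 64, 8, 8
head_dim = hidden_size // num_heads                      # 128

q_proj_shard_size = hidden_size // tp_size               # 1024 rows per rank
kv_proj_shard_size = head_dim * num_kv_heads // tp_size  # 128 rows per rank

attention_weight_specs = [
    # (weight_name, shard_size, offset into the fused qkv_proj parameter)
    ("q_proj", q_proj_shard_size, 0),
    ("k_proj", kv_proj_shard_size, q_proj_shard_size),
    ("v_proj", kv_proj_shard_size, q_proj_shard_size + kv_proj_shard_size),
]
# Per-rank qkv_proj has (num_heads + 2 * num_kv_heads) * head_dim // tp_size
# rows, i.e. 1280, matching 1024 + 128 + 128.
assert q_proj_shard_size + 2 * kv_proj_shard_size == (
    (num_heads + 2 * num_kv_heads) * head_dim // tp_size)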
vllm/model_executor/models/mpt.py (new file)
@@ -0,0 +1,281 @@
# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
                                              hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(
    total_num_heads: int,
    alibi_bias_max: int,
) -> torch.Tensor:
    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / next_power_of_2)
    slopes = 1.0 / torch.pow(2, m)
    if next_power_of_2 != total_num_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
    return slopes


class MPTAttention(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
        assert not config.attn_config["prefix_lm"]
        assert config.attn_config["alibi"]

        self.qkv_proj = ColumnParallelLinear(
            self.d_model,
            3 * self.d_model,
            bias=not config.no_bias,
            gather_output=False,
            perform_initialization=False,
        )
        if self.qk_ln:
            self.q_ln = nn.LayerNorm(self.d_model)
            self.k_ln = nn.LayerNorm(self.d_model)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            input_is_parallel=True,
            perform_initialization=False,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        # Create the alibi slopes and slice them.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
                                         self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        self.head_dim = self.d_model // self.total_num_heads
        scaling = self.head_dim**-0.5
        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
                                            scaling, alibi_slopes)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        del position_ids  # unused.
        qkv, _ = self.qkv_proj(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.out_proj(attn_output)
        return output


class MPTMLP(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        hidden_size = config.d_model
        expansion_ratio = config.expansion_ratio
        intermediate_size = expansion_ratio * hidden_size
        self.up_proj = ColumnParallelLinear(hidden_size,
                                            intermediate_size,
                                            bias=not config.no_bias,
                                            gather_output=False,
                                            perform_initialization=False)
        self.act = get_act_fn("gelu")
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=not config.no_bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.up_proj(x)
        x = self.act(x)
        x, _ = self.down_proj(x)
        return x


class MPTBlock(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        hidden_size = config.d_model
        self.norm_1 = nn.LayerNorm(hidden_size)
        self.attn = MPTAttention(config)
        self.norm_2 = nn.LayerNorm(hidden_size)
        self.ffn = MPTMLP(config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        x = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=x,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = hidden_states + x
        x = self.norm_2(hidden_states)
        x = self.ffn(x)
        hidden_states = hidden_states + x
        return hidden_states


class MPTModel(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        assert config.embedding_fraction == 1.0
        assert config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          config.d_model,
                                          perform_initialization=False)
        self.blocks = nn.ModuleList(
            [MPTBlock(config) for _ in range(config.n_layers)])
        self.norm_f = nn.LayerNorm(config.d_model)
        if config.no_bias:
            for module in self.modules():
                if hasattr(module, "bias"):
                    if isinstance(module.bias, nn.Parameter):
                        # Remove the bias term in Linear and LayerNorm.
                        module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.blocks)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            block = self.blocks[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states


class MPTForCausalLM(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        self.config = config
        assert config.tie_word_embeddings

        self.transformer = MPTModel(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        # parallelism
        self.lm_head_weight = self.transformer.wte.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = ["wte.weight", "up_proj.weight", "up_proj.bias"]
    _row_parallel_weights = ["out_proj.weight", "down_proj.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto"):
        tp_world_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format):
            if "Wqkv" in name:
                # NOTE(woosuk): MPT's fused QKV has the shape of
                # [3 * num_heads * head_size, hidden_size].
                # When tensor model parallelism is used, we need to shard
                # the weight along the hidden dimension.
                total_num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // total_num_heads
                num_heads = total_num_heads // tp_world_size
                head_start = tp_rank * num_heads
                head_end = (tp_rank + 1) * num_heads
                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                if name.endswith(".weight"):
                    loaded_weight = loaded_weight.view(3, total_num_heads,
                                                       head_size, hidden_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif name.endswith(".bias"):
                    loaded_weight = loaded_weight.view(3, total_num_heads,
                                                       head_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :]
                    loaded_weight = loaded_weight.reshape(-1)
                else:
                    raise ValueError(f"Unexpected parameter name {name}")
                name = name.replace("Wqkv", "qkv_proj")
            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights, tp_rank)
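# Illustrative sketch (not part of the diff): what _get_alibi_slopes above
# produces. Each head gets a geometric slope 2**-(i * alibi_bias_max / n),
# computed for the next power of two and subsampled when the head count is not
# a power of two. The head count and bias max used here are hypothetical.
import math
import torch

def alibi_slopes(total_num_heads: int, alibi_bias_max: int = 8) -> torch.Tensor:
    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / next_power_of_2)
    slopes = 1.0 / torch.pow(2, m)
    if next_power_of_2 != total_num_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
    return slopes

print(alibi_slopes(4))  # tensor([0.2500, 0.0625, 0.0156, 0.0039])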
vllm/model_executor/models/opt.py
@@ -1,7 +1,9 @@
|
||||
# coding=utf-8
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
|
||||
# reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -19,7 +21,7 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -35,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
@ -43,8 +45,9 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
class OPTLearnedPositionalEmbedding(nn.Embedding):
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int):
|
||||
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||
# and adjust num_embeddings appropriately. Other models don't have this hack
|
||||
# OPT is set up so that if padding_idx is specified then offset the
|
||||
# embedding ids by 2 and adjust num_embeddings appropriately. Other
|
||||
# models don't have this hack
|
||||
self.offset = 2
|
||||
super().__init__(num_embeddings + self.offset, embedding_dim)
|
||||
|
||||
@ -62,20 +65,26 @@ class OPTAttention(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
total_num_heads = num_heads
|
||||
assert num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = total_num_heads // tensor_model_parallel_world_size
|
||||
self.head_dim = embed_dim // total_num_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
|
||||
self.qkv_proj = ColumnParallelLinear(embed_dim,
|
||||
3 * embed_dim,
|
||||
bias=bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
|
||||
self.out_proj = RowParallelLinear(embed_dim,
|
||||
embed_dim,
|
||||
bias=bias,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.attn = PagedAttention(self.num_heads, self.head_dim,
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
scale=self.scaling)
|
||||
|
||||
def forward(
|
||||
@ -88,8 +97,8 @@ class OPTAttention(nn.Module):
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(
|
||||
q, k, v, key_cache, value_cache, input_metadata, cache_event)
|
||||
attn_output = self.attn(q, k, v, key_cache, value_cache,
|
||||
input_metadata, cache_event)
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
@ -109,17 +118,21 @@ class OPTDecoderLayer(nn.Module):
|
||||
self.activation_fn = get_act_fn(config.activation_function)
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(
|
||||
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
|
||||
self.fc1 = ColumnParallelLinear(self.embed_dim, config.ffn_dim,
|
||||
self.embed_dim,
|
||||
elementwise_affine=config.layer_norm_elementwise_affine)
|
||||
self.fc1 = ColumnParallelLinear(self.embed_dim,
|
||||
config.ffn_dim,
|
||||
bias=config.enable_bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.fc2 = RowParallelLinear(config.ffn_dim, self.embed_dim,
|
||||
self.fc2 = RowParallelLinear(config.ffn_dim,
|
||||
self.embed_dim,
|
||||
bias=config.enable_bias,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.final_layer_norm = nn.LayerNorm(
|
||||
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
|
||||
self.embed_dim,
|
||||
elementwise_affine=config.layer_norm_elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -133,8 +146,7 @@ class OPTDecoderLayer(nn.Module):
|
||||
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
|
||||
if self.do_layer_norm_before:
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event)
|
||||
@ -167,7 +179,8 @@ class OPTDecoder(nn.Module):
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.word_embed_proj_dim,
|
||||
perform_initialization=False)
|
||||
# Positional embeddings are replicated (not sharded).
|
||||
@ -176,26 +189,32 @@ class OPTDecoder(nn.Module):
|
||||
|
||||
# Project out & in will be replicated if they exist.
|
||||
if config.word_embed_proj_dim != config.hidden_size:
|
||||
self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
|
||||
self.project_out = nn.Linear(config.hidden_size,
|
||||
config.word_embed_proj_dim,
|
||||
bias=False)
|
||||
else:
|
||||
self.project_out = None
|
||||
|
||||
if config.word_embed_proj_dim != config.hidden_size:
|
||||
self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
|
||||
self.project_in = nn.Linear(config.word_embed_proj_dim,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
else:
|
||||
self.project_in = None
|
||||
|
||||
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
|
||||
# with checkpoints that have been fine-tuned before transformers v4.20.1
|
||||
# Note that the only purpose of `config._remove_final_layer_norm` is to
|
||||
# keep backward compatibility with checkpoints that have been fine-tuned
|
||||
# before transformers v4.20.1
|
||||
# see https://github.com/facebookresearch/metaseq/pull/164
|
||||
if config.do_layer_norm_before and not config._remove_final_layer_norm:
|
||||
self.final_layer_norm = nn.LayerNorm(
|
||||
config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
|
||||
)
|
||||
config.hidden_size,
|
||||
elementwise_affine=config.layer_norm_elementwise_affine)
|
||||
else:
|
||||
self.final_layer_norm = None
|
||||
|
||||
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layers = nn.ModuleList(
|
||||
[OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -217,8 +236,8 @@ class OPTDecoder(nn.Module):
|
||||
else:
|
||||
cache_event = cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
hidden_states, kv_caches[i], input_metadata, cache_event)
|
||||
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
|
||||
cache_event)
|
||||
|
||||
if self.final_layer_norm is not None:
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
@ -241,8 +260,8 @@ class OPTModel(nn.Module):
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
return self.decoder(
|
||||
input_ids, positions, kv_caches, input_metadata, cache_events)
|
||||
return self.decoder(input_ids, positions, kv_caches, input_metadata,
|
||||
cache_events)
|
||||
|
||||
|
||||
class OPTForCausalLM(nn.Module):
|
||||
@ -263,24 +282,27 @@ class OPTForCausalLM(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, kv_caches, input_metadata, cache_events)
|
||||
next_tokens = self.sampler(
|
||||
self.lm_head_weight, hidden_states, input_metadata)
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
|
||||
_column_parallel_weights = [
|
||||
"embed_tokens.weight", "fc1.weight", "fc1.bias"
|
||||
]
|
||||
_row_parallel_weights = ["out_proj.weight", "fc2.weight"]
|
||||
|
||||
def load_weights(self, model_name_or_path: str,
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
load_format: str = "auto"):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "lm_head.weight" in name:
|
||||
continue
|
||||
|
||||
@ -288,16 +310,17 @@ class OPTForCausalLM(nn.Module):
|
||||
name = "model." + name
|
||||
|
||||
is_attention_weight = False
|
||||
for stride_id, att_weight_name in enumerate(["q_proj", "k_proj", "v_proj"]):
|
||||
for stride_id, att_weight_name in enumerate(
|
||||
["q_proj", "k_proj", "v_proj"]):
|
||||
if att_weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
|
||||
shard_size = param.shape[0] // 3
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id
|
||||
:shard_size * (stride_id + 1)]
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_attention_weight = True
|
||||
|
||||
vllm/model_executor/models/qwen.py (new file, 316 lines)
@ -0,0 +1,316 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
|
||||
# Copyright (c) Alibaba Cloud.
|
||||
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
|
||||
"""Inference-only QWen model compatible with HuggingFace weights.
|
||||
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor,
|
||||
hf_model_weights_iterator,
|
||||
load_padded_tensor_parallel_vocab,
|
||||
load_tensor_parallel_weights,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.qwen import QWenConfig
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
class QWenMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.c_proj(x)
|
||||
return x
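For reference, a minimal stand-alone sketch of what the fused activation above is assumed to compute: `gate_up_proj` produces the gate and up projections concatenated on the last dimension, and `SiluAndMul` applies silu(gate) * up. The helper name and shapes below are illustrative, not part of the diff.

import torch
import torch.nn.functional as F

def silu_and_mul_reference(gate_up: torch.Tensor) -> torch.Tensor:
    # Assumed semantics of the fused SiluAndMul op: split the concatenated
    # [gate, up] projection and return silu(gate) * up.
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up

x = torch.randn(2, 8)                    # hypothetical [num_tokens, 2 * intermediate]
print(silu_and_mul_reference(x).shape)   # torch.Size([2, 4])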
|
||||
|
||||
|
||||
class QWenAttention(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, num_heads: int,
|
||||
max_position_embeddings: int):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
|
||||
)
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = (self.total_num_heads //
|
||||
tensor_model_parallel_world_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
self.c_attn = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||
input_metadata, cache_event)
|
||||
|
||||
output, _ = self.c_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class QWenBlock(nn.Module):
|
||||
|
||||
def __init__(self, config: QWenConfig):
|
||||
super().__init__()
|
||||
self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.attn = QWenAttention(config.n_embd, config.num_attention_heads,
|
||||
config.max_position_embeddings)
|
||||
|
||||
self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.mlp = QWenMLP(config.n_embd, config.ffn_hidden_size // 2)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
hidden_states = self.attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class QWenModel(nn.Module):
|
||||
|
||||
def __init__(self, config: QWenConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.wte = VocabParallelEmbedding(vocab_size,
|
||||
config.n_embd,
|
||||
perform_initialization=False)
|
||||
self.h = nn.ModuleList(
|
||||
[QWenBlock(config) for _ in range(config.num_hidden_layers)])
|
||||
self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.wte(input_ids)
|
||||
for i in range(len(self.h)):
|
||||
if cache_events is None:
|
||||
cache_event = None
|
||||
else:
|
||||
cache_event = cache_events[i]
|
||||
layer = self.h[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class QWenLMHeadModel(nn.Module):
|
||||
|
||||
def __init__(self, config: QWenConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = QWenModel(config)
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.lm_head = ColumnParallelLinear(
|
||||
config.n_embd,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = []
|
||||
_row_parallel_weights = ["c_proj.weight"]
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
):
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
|
||||
if "c_attn" in name:
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // total_num_heads
|
||||
num_heads = total_num_heads // tp_world_size
|
||||
head_start = tp_rank * num_heads
|
||||
head_end = (tp_rank + 1) * num_heads
|
||||
|
||||
if "weight" in name:
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size, hidden_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
elif "bias" in name:
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :]
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
|
||||
is_gate_up_weight = False
|
||||
for stride_id, weight_name in enumerate(["w2", "w1"]):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
|
||||
(tp_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
break
|
||||
if is_gate_up_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
|
||||
if "wte" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tp_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(
|
||||
param,
|
||||
loaded_weight,
|
||||
name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tp_rank,
|
||||
)
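As a reference for the `c_attn` handling above, a minimal stand-alone sketch of the fused-QKV tensor-parallel split using plain tensors; the sizes and rank values are made up for illustration.

import torch

total_num_heads, head_size = 8, 4                   # hypothetical sizes
hidden_size = total_num_heads * head_size
fused_qkv = torch.randn(3 * hidden_size, hidden_size)

tp_world_size, tp_rank = 2, 0                       # pretend rank 0 of 2
num_heads = total_num_heads // tp_world_size
head_start, head_end = tp_rank * num_heads, (tp_rank + 1) * num_heads

# Same view/slice/reshape pattern as the "weight" branch of the c_attn case.
shard = (fused_qkv.view(3, total_num_heads, head_size, hidden_size)
         [:, head_start:head_end, :, :].reshape(-1, hidden_size))
assert shard.shape == (3 * num_heads * head_size, hidden_size)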
|
||||
@ -1,9 +1,6 @@
import vllm.model_executor.parallel_utils.parallel_state
import vllm.model_executor.parallel_utils.tensor_parallel

# Alias parallel_state as mpu, its legacy name
mpu = parallel_state

__all__ = [
    "parallel_state",
    "tensor_parallel",
@ -44,7 +44,6 @@ _PIPELINE_GLOBAL_RANKS = None
|
||||
# rank when broadcasting weights from src to all other data parallel ranks
|
||||
_DATA_PARALLEL_GLOBAL_RANKS = None
|
||||
|
||||
_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None
|
||||
|
||||
def initialize_model_parallel(
|
||||
tensor_model_parallel_size: int = 1,
|
||||
@ -196,20 +195,6 @@ def initialize_model_parallel(
|
||||
if rank in ranks:
|
||||
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
|
||||
|
||||
def initialize_all_reduce_launcher(
|
||||
max_num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
disable_graph: bool = False,
|
||||
) -> None:
|
||||
global _ALL_REDUCE_LAUNCHER
|
||||
_ALL_REDUCE_LAUNCHER = GraphAllReduce(
|
||||
max_num_tokens=max_num_tokens,
|
||||
hidden_size=hidden_size,
|
||||
dtype=dtype,
|
||||
disable_graph=disable_graph,
|
||||
)
|
||||
|
||||
def model_parallel_is_initialized():
|
||||
"""Check if model and data parallel groups are initialized."""
|
||||
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
|
||||
@ -458,6 +443,7 @@ def get_pipeline_model_parallel_last_rank():
|
||||
last_rank_local = get_pipeline_model_parallel_world_size() - 1
|
||||
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_next_rank():
|
||||
"""Return the global rank that follows the caller in the pipeline"""
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
||||
@ -485,10 +471,6 @@ def get_data_parallel_rank():
|
||||
"""Return my rank for the data parallel group."""
|
||||
return torch.distributed.get_rank(group=get_data_parallel_group())
|
||||
|
||||
def get_all_reduce_launcher() -> 'GraphAllReduce':
|
||||
assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
|
||||
return _ALL_REDUCE_LAUNCHER
|
||||
|
||||
def destroy_model_parallel():
|
||||
"""Set the groups to none."""
|
||||
global _MODEL_PARALLEL_GROUP
|
||||
@ -515,56 +497,3 @@ def destroy_model_parallel():
|
||||
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
|
||||
|
||||
|
||||
class GraphAllReduce:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
disable_graph: bool = False,
|
||||
) -> None:
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.hidden_size = hidden_size
|
||||
self.disable_graph = disable_graph
|
||||
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
if tp_world_size == 1:
|
||||
return
|
||||
|
||||
self.group = get_tensor_model_parallel_group()
|
||||
self.buffer = torch.empty(
|
||||
size=(max_num_tokens, hidden_size),
|
||||
dtype=dtype,
|
||||
device='cuda',
|
||||
)
|
||||
|
||||
# Build graphs for different number of tokens.
|
||||
if not self.disable_graph:
|
||||
self.graphs = {}
|
||||
for num_tokens in range(8, max_num_tokens + 1, 8):
|
||||
self.graphs[num_tokens] = self._build_graph(num_tokens)
|
||||
|
||||
def _build_graph(self, num_tokens: int) -> torch.cuda.CUDAGraph:
|
||||
# Warm up.
|
||||
torch.distributed.all_reduce(self.buffer[:num_tokens], group=self.group)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Build graph.
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
torch.distributed.all_reduce(
|
||||
self.buffer[:num_tokens], group=self.group)
|
||||
torch.cuda.synchronize()
|
||||
return graph
|
||||
|
||||
def launch(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE: x must be a slice of self.buffer.
|
||||
num_tokens = x.shape[0]
|
||||
if self.disable_graph:
|
||||
torch.distributed.all_reduce(x, group=self.group)
|
||||
else:
|
||||
self.graphs[num_tokens].replay()
|
||||
return x
|
||||
|
||||
@ -12,6 +12,7 @@ from .mappings import (
|
||||
copy_to_tensor_model_parallel_region,
|
||||
gather_from_tensor_model_parallel_region,
|
||||
gather_from_sequence_parallel_region,
|
||||
reduce_from_tensor_model_parallel_region,
|
||||
scatter_to_tensor_model_parallel_region,
|
||||
scatter_to_sequence_parallel_region,
|
||||
)
|
||||
@ -38,7 +39,7 @@ __all__ = [
|
||||
"copy_to_tensor_model_parallel_region",
|
||||
"gather_from_tensor_model_parallel_region",
|
||||
"gather_from_sequence_parallel_region",
|
||||
# "reduce_from_tensor_model_parallel_region",
|
||||
"reduce_from_tensor_model_parallel_region",
|
||||
"scatter_to_tensor_model_parallel_region",
|
||||
"scatter_to_sequence_parallel_region",
|
||||
# random.py
|
||||
|
||||
@ -14,7 +14,6 @@ from torch.nn.parameter import Parameter
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_all_reduce_launcher,
|
||||
)
|
||||
from .mappings import (
|
||||
copy_to_tensor_model_parallel_region,
|
||||
@ -248,8 +247,8 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
self.output_size = output_size
|
||||
self.gather_output = gather_output
|
||||
# Divide the weight matrix along the last dimension.
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
self.output_size_per_partition = divide(output_size, world_size)
|
||||
self.world_size = get_tensor_model_parallel_world_size()
|
||||
self.output_size_per_partition = divide(output_size, self.world_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
|
||||
if params_dtype is None:
|
||||
@ -350,6 +349,7 @@ class RowParallelLinear(torch.nn.Module):
|
||||
params_dtype:
|
||||
use_cpu_initialization:
|
||||
perform_initialization:
|
||||
reduce_results:
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, output_size, *,
|
||||
@ -360,6 +360,7 @@ class RowParallelLinear(torch.nn.Module):
|
||||
params_dtype=None,
|
||||
use_cpu_initialization=False,
|
||||
perform_initialization=True,
|
||||
reduce_results=True,
|
||||
):
|
||||
super(RowParallelLinear, self).__init__()
|
||||
|
||||
@ -367,14 +368,19 @@ class RowParallelLinear(torch.nn.Module):
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, world_size)
|
||||
self.world_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, self.world_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
"results can lead to incorrect results")
|
||||
|
||||
# Parameters.
|
||||
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
||||
# we allocate the transpose.
|
||||
@ -427,17 +433,12 @@ class RowParallelLinear(torch.nn.Module):
|
||||
input_parallel = input_
|
||||
else:
|
||||
input_parallel = scatter_to_tensor_model_parallel_region(input_)
|
||||
if get_tensor_model_parallel_world_size() == 1:
|
||||
# Matrix multiply.
|
||||
output_ = F.linear(input_parallel, self.weight)
|
||||
output_parallel = F.linear(input_parallel, self.weight)
|
||||
if self.reduce_results and self.world_size > 1:
|
||||
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
|
||||
else:
|
||||
# Matrix multiply.
|
||||
all_reduce_launcher = get_all_reduce_launcher()
|
||||
num_tokens = input_parallel.shape[0]
|
||||
output_buffer = all_reduce_launcher.buffer[:num_tokens]
|
||||
torch.matmul(input_parallel, self.weight_t, out=output_buffer)
|
||||
# All-reduce across all the partitions.
|
||||
output_ = all_reduce_launcher.launch(output_buffer)
|
||||
output_ = output_parallel
|
||||
|
||||
if not self.skip_bias_add:
|
||||
output = output_ + self.bias if self.bias is not None else output_
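A single-process sketch of the row-parallel matmul that the forward pass above distributes: each rank multiplies its input slice by its weight slice, and the all-reduce sums the partial results. Sizes are illustrative and no distributed setup is assumed.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, input_size, output_size, world_size = 4, 8, 6, 2
x = torch.randn(num_tokens, input_size)
w = torch.randn(output_size, input_size)            # F.linear convention: [out, in]

per_rank = input_size // world_size
partials = [
    F.linear(x[:, r * per_rank:(r + 1) * per_rank],
             w[:, r * per_rank:(r + 1) * per_rank])
    for r in range(world_size)
]
# Summing the partials plays the role of reduce_from_tensor_model_parallel_region.
assert torch.allclose(sum(partials), F.linear(x, w), atol=1e-5)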
|
||||
|
||||
@ -3,13 +3,19 @@ import filelock
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from typing import Iterator, List, Optional, Tuple
|
||||
from collections import defaultdict
|
||||
from typing import Iterator, List, Optional, Tuple, Any
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Disabledtqdm(tqdm):
|
||||
|
||||
@ -17,50 +23,150 @@ class Disabledtqdm(tqdm):
|
||||
super().__init__(*args, **kwargs, disable=True)
|
||||
|
||||
|
||||
def hf_model_weights_iterator(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False,
|
||||
) -> Iterator[Tuple[str, torch.Tensor]]:
|
||||
# Prepare file lock directory to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
|
||||
lock_dir = cache_dir if cache_dir is not None else "/tmp"
|
||||
lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
|
||||
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
|
||||
return lock
|
||||
|
||||
|
||||
def _shared_pointers(tensors):
|
||||
ptrs = defaultdict(list)
|
||||
for k, v in tensors.items():
|
||||
ptrs[v.data_ptr()].append(k)
|
||||
failing = []
|
||||
for _, names in ptrs.items():
|
||||
if len(names) > 1:
|
||||
failing.append(names)
|
||||
return failing
|
||||
|
||||
|
||||
def convert_bin_to_safetensor_file(
|
||||
pt_filename: str,
|
||||
sf_filename: str,
|
||||
):
|
||||
loaded = torch.load(pt_filename, map_location="cpu")
|
||||
if "state_dict" in loaded:
|
||||
loaded = loaded["state_dict"]
|
||||
shared = _shared_pointers(loaded)
|
||||
for shared_weights in shared:
|
||||
for name in shared_weights[1:]:
|
||||
loaded.pop(name)
|
||||
|
||||
# For tensors to be contiguous
|
||||
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
||||
|
||||
dirname = os.path.dirname(sf_filename)
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
save_file(loaded, sf_filename, metadata={"format": "pt"})
|
||||
|
||||
# check file size
|
||||
sf_size = os.stat(sf_filename).st_size
|
||||
pt_size = os.stat(pt_filename).st_size
|
||||
if (sf_size - pt_size) / pt_size > 0.01:
|
||||
raise RuntimeError(f"""The file size different is more than 1%:
|
||||
- {sf_filename}: {sf_size}
|
||||
- {pt_filename}: {pt_size}
|
||||
""")
|
||||
|
||||
# check if the tensors are the same
|
||||
reloaded = load_file(sf_filename)
|
||||
for k in loaded:
|
||||
pt_tensor = loaded[k]
|
||||
sf_tensor = reloaded[k]
|
||||
if not torch.equal(pt_tensor, sf_tensor):
|
||||
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||
|
||||
|
||||
def prepare_hf_model_weights(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_safetensors: bool = False,
|
||||
fall_back_to_pt: bool = True,
|
||||
):
|
||||
# Download model weights from huggingface.
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
|
||||
if not is_local:
|
||||
with lock:
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
hf_folder = snapshot_download(model_name_or_path,
|
||||
allow_patterns="*.bin",
|
||||
allow_patterns=allow_patterns,
|
||||
cache_dir=cache_dir,
|
||||
tqdm_class=Disabledtqdm)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
|
||||
if not use_safetensors:
|
||||
hf_weights_files = [
|
||||
x for x in hf_weights_files if not x.endswith("training_args.bin")
|
||||
]
|
||||
|
||||
hf_bin_files = glob.glob(os.path.join(hf_folder, "*.bin"))
|
||||
if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
|
||||
return prepare_hf_model_weights(model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
use_safetensors=False,
|
||||
fall_back_to_pt=False)
|
||||
|
||||
if len(hf_weights_files) == 0:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any model weights with `{model_name_or_path}`")
|
||||
|
||||
return hf_folder, hf_weights_files, use_safetensors
|
||||
|
||||
|
||||
def hf_model_weights_iterator(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
) -> Iterator[Tuple[str, torch.Tensor]]:
|
||||
use_safetensors = False
|
||||
use_np_cache = False
|
||||
fall_back_to_pt = False
|
||||
if load_format == "auto":
|
||||
use_safetensors = True
|
||||
fall_back_to_pt = True
|
||||
elif load_format == "safetensors":
|
||||
use_safetensors = True
|
||||
elif load_format == "pt":
|
||||
pass
|
||||
elif load_format == "npcache":
|
||||
use_np_cache = True
|
||||
else:
|
||||
raise ValueError(f"Unknown load_format: {load_format}")
|
||||
|
||||
hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
use_safetensors=use_safetensors,
|
||||
fall_back_to_pt=fall_back_to_pt)
|
||||
|
||||
if use_np_cache:
|
||||
# Currently np_cache only support *.bin checkpoints
|
||||
assert use_safetensors is False
|
||||
|
||||
# Convert the model weights from torch tensors to numpy arrays for
|
||||
# faster loading.
|
||||
np_folder = os.path.join(hf_folder, 'np')
|
||||
np_folder = os.path.join(hf_folder, "np")
|
||||
os.makedirs(np_folder, exist_ok=True)
|
||||
weight_names_file = os.path.join(np_folder, 'weight_names.json')
|
||||
with lock:
|
||||
weight_names_file = os.path.join(np_folder, "weight_names.json")
|
||||
# Use file lock to prevent multiple processes from
|
||||
# dumping the same model weights to numpy at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
if not os.path.exists(weight_names_file):
|
||||
weight_names = []
|
||||
for bin_file in hf_bin_files:
|
||||
for bin_file in hf_weights_files:
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
for name, param in state.items():
|
||||
param_path = os.path.join(np_folder, name)
|
||||
with open(param_path, "wb") as f:
|
||||
np.save(f, param.cpu().detach().numpy())
|
||||
weight_names.append(name)
|
||||
with open(weight_names_file, 'w') as f:
|
||||
with open(weight_names_file, "w") as f:
|
||||
json.dump(weight_names, f)
|
||||
|
||||
with open(weight_names_file, 'r') as f:
|
||||
with open(weight_names_file, "r") as f:
|
||||
weight_names = json.load(f)
|
||||
|
||||
for name in weight_names:
|
||||
@ -68,16 +174,52 @@ def hf_model_weights_iterator(
|
||||
with open(param_path, "rb") as f:
|
||||
param = np.load(f)
|
||||
yield name, torch.from_numpy(param)
|
||||
elif use_safetensors:
|
||||
for st_file in hf_weights_files:
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
for name in f.keys():
|
||||
param = f.get_slice(name)
|
||||
yield name, param
|
||||
else:
|
||||
for bin_file in hf_bin_files:
|
||||
for bin_file in hf_weights_files:
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
for name, param in state.items():
|
||||
yield name, param
|
||||
del state
|
||||
torch.cuda.empty_cache()
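A hedged usage sketch of the iterator after this change; the model name is illustrative, and "auto" is assumed to try safetensors files first and fall back to .bin checkpoints, as the dispatch above implements.

from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
                                              hf_model_weights_iterator)

for name, loaded in hf_model_weights_iterator("facebook/opt-125m",
                                              cache_dir=None,
                                              load_format="auto"):
    tensor = convert_pyslice_to_tensor(loaded)   # no-op for plain tensors
    print(name, tuple(tensor.shape))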
|
||||
|
||||
|
||||
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
|
||||
"""convert PySafeSlice object from safetensors to torch.Tensor
|
||||
|
||||
PySafeSlice object supports indexing, which is done before loading the
|
||||
actual tensor and can reduce the amount of memory being read into the
|
||||
memory. However, it does not support more advanced functionalities
|
||||
like `.view()` or `.t()`. Therefore, if we need to modify the loaded
|
||||
tensor with these more complicated operators, we need to convert to
|
||||
tensor first.
|
||||
"""
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = x[:]
|
||||
return x
|
||||
|
||||
|
||||
def load_padded_tensor_parallel_vocab(
|
||||
param: torch.Tensor,
|
||||
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
|
||||
tensor_model_parallel_rank: int,
|
||||
) -> None:
|
||||
shard_size = param.shape[0]
|
||||
start_idx = tensor_model_parallel_rank * shard_size
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
loaded_weight = loaded_weight[start_idx:end_idx]
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
param[:loaded_weight.shape[0]].copy_(loaded_weight)
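A stand-alone sketch of the padded-vocab sharding handled above, with made-up sizes: the embedding is padded to a multiple of 64, each rank owns a contiguous shard, and the last shard's padding rows are simply left untouched.

import torch

vocab_size, hidden = 151_851, 16                    # hypothetical checkpoint vocab
padded_vocab = ((vocab_size + 63) // 64) * 64       # same rounding as QWenModel
tp_world_size, tp_rank = 2, 1
shard_size = padded_vocab // tp_world_size

checkpoint = torch.randn(vocab_size, hidden)
param = torch.zeros(shard_size, hidden)             # this rank's padded shard

start_idx = tp_rank * shard_size
end_idx = (tp_rank + 1) * shard_size
loaded = checkpoint[start_idx:end_idx]              # shorter than shard_size here
param[:loaded.shape[0]].copy_(loaded)               # padding rows stay zero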
|
||||
|
||||
|
||||
def load_tensor_parallel_weights(
|
||||
param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
|
||||
param_name: str,
|
||||
column_parallel_weight_names: List[str],
|
||||
row_parallel_weight_names: List[str],
|
||||
@ -86,19 +228,22 @@ def load_tensor_parallel_weights(
|
||||
for p in column_parallel_weight_names:
|
||||
if p in param_name:
|
||||
shard_size = param.shape[0]
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
start_idx = tensor_model_parallel_rank * shard_size
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
loaded_weight = loaded_weight[start_idx:end_idx]
|
||||
break
|
||||
for p in row_parallel_weight_names:
|
||||
if p in param_name:
|
||||
shard_size = param.shape[1]
|
||||
loaded_weight = loaded_weight[
|
||||
:,
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
start_idx = tensor_model_parallel_rank * shard_size
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
loaded_weight = loaded_weight[:, start_idx:end_idx]
|
||||
break
|
||||
assert param.shape == loaded_weight.shape
|
||||
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
assert param.shape == loaded_weight.shape, (
|
||||
f"{param_name} shape mismatch between model and checkpoint: "
|
||||
f"{param.shape} != {loaded_weight.shape}")
|
||||
param.data.copy_(loaded_weight)
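For contrast with the vocab case, a small sketch of the two sharding directions above: column-parallel weights are sliced along dim 0, row-parallel weights along dim 1. Shapes are illustrative only.

import torch

full_weight = torch.randn(12, 8)
tp_rank, tp_world_size = 1, 2

# Column-parallel (e.g. fc1.weight): each rank keeps a slice of the output rows.
col_shard_size = full_weight.shape[0] // tp_world_size
col_shard = full_weight[col_shard_size * tp_rank:col_shard_size * (tp_rank + 1)]

# Row-parallel (e.g. fc2.weight): each rank keeps a slice of the input columns.
row_shard_size = full_weight.shape[1] // tp_world_size
row_shard = full_weight[:, row_shard_size * tp_rank:row_shard_size * (tp_rank + 1)]

print(col_shard.shape, row_shard.shape)   # torch.Size([6, 8]) torch.Size([12, 4])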
|
||||
|
||||
|
||||
|
||||
@ -55,6 +55,7 @@ class RequestOutput:
|
||||
outputs: The output sequences of the request.
|
||||
finished: Whether the whole request is finished.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
@ -74,9 +75,12 @@ class RequestOutput:
|
||||
# Get the top-n sequences.
|
||||
n = seq_group.sampling_params.n
|
||||
seqs = seq_group.get_seqs()
|
||||
assert n <= len(seqs)
|
||||
sorted_seqs = sorted(
|
||||
seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True)
|
||||
if seq_group.sampling_params.use_beam_search:
|
||||
sorting_key = lambda seq: seq.get_beam_search_score(
|
||||
seq_group.sampling_params.length_penalty)
|
||||
else:
|
||||
sorting_key = lambda seq: seq.get_cumulative_logprob()
|
||||
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
|
||||
top_n_seqs = sorted_seqs[:n]
|
||||
|
||||
# Create the outputs.
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
"""Sampling parameters for text generation."""
|
||||
from typing import List, Optional, Union
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class SamplingParams:
|
||||
"""Sampling parameters for text generation.
|
||||
@ -32,6 +34,15 @@ class SamplingParams:
|
||||
top_k: Integer that controls the number of top tokens to consider. Set
|
||||
to -1 to consider all tokens.
|
||||
use_beam_search: Whether to use beam search instead of sampling.
|
||||
length_penalty: Float that penalizes sequences based on their length.
|
||||
Used in beam search.
|
||||
early_stopping: Controls the stopping condition for beam search. It
|
||||
accepts the following values: `True`, where the generation stops as
|
||||
soon as there are `best_of` complete candidates; `False`, where a
|
||||
heuristic is applied and the generation stops when it is very
|
||||
unlikely to find better candidates; `"never"`, where the beam search
|
||||
procedure only stops when there cannot be better candidates
|
||||
(canonical beam search algorithm).
|
||||
stop: List of strings that stop the generation when they are generated.
|
||||
The returned output will not contain the stop strings.
|
||||
ignore_eos: Whether to ignore the EOS token and continue generating
|
||||
@ -50,7 +61,9 @@ class SamplingParams:
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
use_beam_search: bool = False,
|
||||
stop: Union[str, List[str]] = [],
|
||||
length_penalty: float = 1.0,
|
||||
early_stopping: Union[bool, str] = False,
|
||||
stop: Union[None, str, List[str]] = None,
|
||||
ignore_eos: bool = False,
|
||||
max_tokens: int = 16,
|
||||
logprobs: Optional[int] = None,
|
||||
@ -63,15 +76,24 @@ class SamplingParams:
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.use_beam_search = use_beam_search
|
||||
self.stop = [stop] if isinstance(stop, str) else list(stop)
|
||||
self.length_penalty = length_penalty
|
||||
self.early_stopping = early_stopping
|
||||
if stop is None:
|
||||
self.stop = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
else:
|
||||
self.stop = list(stop)
|
||||
self.ignore_eos = ignore_eos
|
||||
self.max_tokens = max_tokens
|
||||
self.logprobs = logprobs
|
||||
|
||||
self._verify_args()
|
||||
if self.use_beam_search:
|
||||
self._verity_beam_search()
|
||||
elif self.temperature == 0.0:
|
||||
self._verify_beam_search()
|
||||
else:
|
||||
self._verify_non_beam_search()
|
||||
if self.temperature < _SAMPLING_EPS:
|
||||
# Zero temperature means greedy sampling.
|
||||
self._verify_greedy_sampling()
|
||||
|
||||
@ -102,22 +124,36 @@ class SamplingParams:
|
||||
raise ValueError(
|
||||
f"logprobs must be non-negative, got {self.logprobs}.")
|
||||
|
||||
def _verity_beam_search(self) -> None:
|
||||
def _verify_beam_search(self) -> None:
|
||||
if self.best_of == 1:
|
||||
raise ValueError("best_of must be greater than 1 when using beam "
|
||||
f"search. Got {self.best_of}.")
|
||||
if self.temperature > 0.0:
|
||||
if self.temperature > _SAMPLING_EPS:
|
||||
raise ValueError("temperature must be 0 when using beam search.")
|
||||
if self.top_p < 1.0:
|
||||
if self.top_p < 1.0 - _SAMPLING_EPS:
|
||||
raise ValueError("top_p must be 1 when using beam search.")
|
||||
if self.top_k != -1:
|
||||
raise ValueError("top_k must be -1 when using beam search.")
|
||||
if self.early_stopping not in [True, False, "never"]:
|
||||
raise ValueError(
|
||||
f"early_stopping must be True, False, or 'never', "
|
||||
f"got {self.early_stopping}.")
|
||||
|
||||
def _verify_non_beam_search(self) -> None:
|
||||
if self.early_stopping is not False:
|
||||
raise ValueError("early_stopping is not effective and must be "
|
||||
"False when not using beam search.")
|
||||
if (self.length_penalty < 1.0 - _SAMPLING_EPS
|
||||
or self.length_penalty > 1.0 + _SAMPLING_EPS):
|
||||
raise ValueError(
|
||||
"length_penalty is not effective and must be the "
|
||||
"default value of 1.0 when not using beam search.")
|
||||
|
||||
def _verify_greedy_sampling(self) -> None:
|
||||
if self.best_of > 1:
|
||||
raise ValueError("best_of must be 1 when using greedy sampling."
|
||||
f"Got {self.best_of}.")
|
||||
if self.top_p < 1.0:
|
||||
if self.top_p < 1.0 - _SAMPLING_EPS:
|
||||
raise ValueError("top_p must be 1 when using greedy sampling.")
|
||||
if self.top_k != -1:
|
||||
raise ValueError("top_k must be -1 when using greedy sampling.")
|
||||
@ -131,6 +167,8 @@ class SamplingParams:
|
||||
f"top_p={self.top_p}, "
|
||||
f"top_k={self.top_k}, "
|
||||
f"use_beam_search={self.use_beam_search}, "
|
||||
f"length_penalty={self.length_penalty}, "
|
||||
f"early_stopping={self.early_stopping}, "
|
||||
f"stop={self.stop}, "
|
||||
f"ignore_eos={self.ignore_eos}, "
|
||||
f"max_tokens={self.max_tokens}, "
|
||||
|
||||
vllm/sequence.py (180 lines changed)
@ -1,3 +1,4 @@
|
||||
"""Sequence and its related classes."""
|
||||
import copy
|
||||
import enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
@ -7,12 +8,14 @@ from vllm.sampling_params import SamplingParams
|
||||
|
||||
|
||||
class SequenceStatus(enum.Enum):
|
||||
"""Status of a sequence."""
|
||||
WAITING = enum.auto()
|
||||
RUNNING = enum.auto()
|
||||
SWAPPED = enum.auto()
|
||||
FINISHED_STOPPED = enum.auto()
|
||||
FINISHED_LENGTH_CAPPED = enum.auto()
|
||||
FINISHED_ABORTED = enum.auto()
|
||||
FINISHED_IGNORED = enum.auto()
|
||||
|
||||
@staticmethod
|
||||
def is_finished(status: "SequenceStatus") -> bool:
|
||||
@ -20,6 +23,7 @@ class SequenceStatus(enum.Enum):
|
||||
SequenceStatus.FINISHED_STOPPED,
|
||||
SequenceStatus.FINISHED_LENGTH_CAPPED,
|
||||
SequenceStatus.FINISHED_ABORTED,
|
||||
SequenceStatus.FINISHED_IGNORED,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@ -30,12 +34,25 @@ class SequenceStatus(enum.Enum):
|
||||
finish_reason = "length"
|
||||
elif status == SequenceStatus.FINISHED_ABORTED:
|
||||
finish_reason = "abort"
|
||||
elif status == SequenceStatus.FINISHED_IGNORED:
|
||||
finish_reason = "length"
|
||||
else:
|
||||
finish_reason = None
|
||||
return finish_reason
|
||||
|
||||
|
||||
class SequenceData:
|
||||
"""Data associated with a sequence.
|
||||
|
||||
|
||||
Args:
|
||||
prompt_token_ids: The token IDs of the prompt.
|
||||
|
||||
Attributes:
|
||||
prompt_token_ids: The token IDs of the prompt.
|
||||
output_token_ids: The token IDs of the output.
|
||||
cumulative_logprob: The cumulative log probability of the output.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -52,6 +69,9 @@ class SequenceData:
|
||||
def get_len(self) -> int:
|
||||
return len(self.output_token_ids) + len(self.prompt_token_ids)
|
||||
|
||||
def get_prompt_len(self) -> int:
|
||||
return len(self.prompt_token_ids)
|
||||
|
||||
def get_output_len(self) -> int:
|
||||
return len(self.output_token_ids)
|
||||
|
||||
@ -71,6 +91,15 @@ class SequenceData:
|
||||
|
||||
|
||||
class Sequence:
|
||||
"""Stores the data, status, and block information of a sequence.
|
||||
|
||||
Args:
|
||||
seq_id: The ID of the sequence.
|
||||
prompt: The prompt of the sequence.
|
||||
prompt_token_ids: The token IDs of the prompt.
|
||||
block_size: The block size of the sequence. Should be the same as the
|
||||
block size used by the block manager and cache engine.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -101,7 +130,8 @@ class Sequence:
|
||||
self.logical_token_blocks.append(block)
|
||||
|
||||
def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
|
||||
while token_ids:
|
||||
cursor = 0
|
||||
while cursor < len(token_ids):
|
||||
if not self.logical_token_blocks:
|
||||
self._append_logical_block()
|
||||
|
||||
@ -111,8 +141,9 @@ class Sequence:
|
||||
last_block = self.logical_token_blocks[-1]
|
||||
|
||||
num_empty_slots = last_block.get_num_empty_slots()
|
||||
last_block.append_tokens(token_ids[:num_empty_slots])
|
||||
token_ids = token_ids[num_empty_slots:]
|
||||
last_block.append_tokens(token_ids[cursor:cursor +
|
||||
num_empty_slots])
|
||||
cursor += num_empty_slots
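A pure-Python sketch of the cursor-based filling above, using fixed-capacity lists in place of logical token blocks; block_size and the token values are illustrative.

def append_tokens_to_blocks(blocks, token_ids, block_size=4):
    cursor = 0
    while cursor < len(token_ids):
        if not blocks or len(blocks[-1]) == block_size:
            blocks.append([])                          # open a new logical block
        num_empty_slots = block_size - len(blocks[-1])
        blocks[-1].extend(token_ids[cursor:cursor + num_empty_slots])
        cursor += num_empty_slots
    return blocks

print(append_tokens_to_blocks([], list(range(10))))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]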
|
||||
|
||||
def append_token_id(
|
||||
self,
|
||||
@ -127,6 +158,9 @@ class Sequence:
|
||||
def get_len(self) -> int:
|
||||
return self.data.get_len()
|
||||
|
||||
def get_prompt_len(self) -> int:
|
||||
return self.data.get_prompt_len()
|
||||
|
||||
def get_output_len(self) -> int:
|
||||
return self.data.get_output_len()
|
||||
|
||||
@ -142,22 +176,48 @@ class Sequence:
|
||||
def get_cumulative_logprob(self) -> float:
|
||||
return self.data.cumulative_logprob
|
||||
|
||||
def get_beam_search_score(self,
|
||||
length_penalty: float = 0.0,
|
||||
seq_len: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None) -> float:
|
||||
"""Calculate the beam search score with length penalty.
|
||||
|
||||
Adapted from
|
||||
|
||||
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
|
||||
"""
|
||||
if seq_len is None:
|
||||
seq_len = self.get_len()
|
||||
# Note: HF implementation does not count the EOS token
|
||||
# towards the length, we align with that here for testing.
|
||||
if (eos_token_id is not None
|
||||
and self.get_last_token_id() == eos_token_id):
|
||||
seq_len -= 1
|
||||
return self.get_cumulative_logprob() / (seq_len**length_penalty)
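A toy check of the scoring rule above, with made-up log-probabilities: the cumulative log-probability is divided by the sequence length raised to the length penalty.

token_logprobs = [-0.1, -0.3, -0.2, -0.4]       # hypothetical per-token logprobs
cumulative_logprob = sum(token_logprobs)        # -1.0
seq_len, length_penalty = 4, 1.2
score = cumulative_logprob / (seq_len ** length_penalty)
print(round(score, 4))                          # -0.1895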
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return SequenceStatus.is_finished(self.status)
|
||||
|
||||
def fork(self, child_seq: 'Sequence') -> None:
|
||||
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
|
||||
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
|
||||
child_seq.data = copy.deepcopy(self.data)
|
||||
return None
|
||||
def fork(self, new_seq_id: int) -> "Sequence":
|
||||
new_seq = copy.deepcopy(self)
|
||||
new_seq.seq_id = new_seq_id
|
||||
return new_seq
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f'Sequence(seq_id={self.seq_id}, '
|
||||
f'status={self.status.name}, '
|
||||
f'num_blocks={len(self.logical_token_blocks)})')
|
||||
return (f"Sequence(seq_id={self.seq_id}, "
|
||||
f"status={self.status.name}, "
|
||||
f"num_blocks={len(self.logical_token_blocks)})")
|
||||
|
||||
|
||||
class SequenceGroup:
|
||||
"""A group of sequences that are generated from the same prompt.
|
||||
|
||||
Args:
|
||||
request_id: The ID of the request.
|
||||
seqs: The list of sequences.
|
||||
sampling_params: The sampling parameters used to generate the outputs.
|
||||
arrival_time: The arrival time of the request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -167,46 +227,88 @@ class SequenceGroup:
|
||||
arrival_time: float,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.seqs = seqs
|
||||
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
|
||||
self.sampling_params = sampling_params
|
||||
self.arrival_time = arrival_time
|
||||
|
||||
def get_max_num_running_seqs(self) -> int:
|
||||
"""The maximum number of sequences running in parallel in the remaining
|
||||
lifetime of the request."""
|
||||
if self.sampling_params.use_beam_search:
|
||||
# For beam search, maximally there will always be `best_of` beam
|
||||
# candidates running in the future.
|
||||
return self.sampling_params.best_of
|
||||
else:
|
||||
if self.sampling_params.best_of > self.num_seqs():
|
||||
# At prompt stage, the sequence group is not yet filled up
|
||||
# and only have one sequence running. However, in the
|
||||
# generation stage, we will have `best_of` sequences running.
|
||||
return self.sampling_params.best_of
|
||||
# At sampling stages, return the number of actual sequences
|
||||
# running.
|
||||
return self.num_seqs(status=SequenceStatus.RUNNING)
|
||||
|
||||
def get_seqs(
|
||||
self,
|
||||
status: Optional[SequenceStatus] = None,
|
||||
) -> List[Sequence]:
|
||||
if status is None:
|
||||
return self.seqs
|
||||
return list(self.seqs_dict.values())
|
||||
else:
|
||||
return [seq for seq in self.seqs if seq.status == status]
|
||||
return [
|
||||
seq for seq in self.seqs_dict.values() if seq.status == status
|
||||
]
|
||||
|
||||
def get_finished_seqs(self) -> List[Sequence]:
|
||||
return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
|
||||
|
||||
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
|
||||
return len(self.get_seqs(status))
|
||||
|
||||
def find(self, seq_id: int) -> Sequence:
|
||||
for seq in self.seqs:
|
||||
if seq.seq_id == seq_id:
|
||||
return seq
|
||||
raise ValueError(f'Sequence {seq_id} not found.')
|
||||
if seq_id not in self.seqs_dict:
|
||||
raise ValueError(f"Sequence {seq_id} not found.")
|
||||
return self.seqs_dict[seq_id]
|
||||
|
||||
def add(self, seq: Sequence) -> None:
|
||||
if seq.seq_id in self.seqs_dict:
|
||||
raise ValueError(f"Sequence {seq.seq_id} already exists.")
|
||||
self.seqs_dict[seq.seq_id] = seq
|
||||
|
||||
def remove(self, seq_id: int) -> None:
|
||||
if seq_id not in self.seqs_dict:
|
||||
raise ValueError(f"Sequence {seq_id} not found.")
|
||||
del self.seqs_dict[seq_id]
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return all(seq.is_finished() for seq in self.seqs)
|
||||
return all(seq.is_finished() for seq in self.get_seqs())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SequenceGroup(request_id={self.request_id}, "
|
||||
f"sampling_params={self.sampling_params}, "
|
||||
f"num_seqs={len(self.seqs)})")
|
||||
f"num_seqs={len(self.seqs_dict)})")
|
||||
|
||||
|
||||
class SequenceGroupMetadata:
|
||||
"""Metadata for a sequence group. Used to create `InputMetadata`.
|
||||
|
||||
|
||||
Args:
|
||||
request_id: The ID of the request.
|
||||
is_prompt: Whether the request is at prompt stage.
|
||||
seq_data: The sequence data. (Seq id -> sequence data)
|
||||
sampling_params: The sampling parameters used to generate the outputs.
|
||||
block_tables: The block tables. (Seq id -> list of physical block
|
||||
numbers)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
is_prompt: bool,
|
||||
seq_data: Dict[int, SequenceData], # Seq id -> sequence data.
|
||||
seq_data: Dict[int, SequenceData],
|
||||
sampling_params: SamplingParams,
|
||||
block_tables: Dict[int, List[int]], # Seq id -> list of physical block numbers.
|
||||
block_tables: Dict[int, List[int]],
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.is_prompt = is_prompt
|
||||
@ -216,29 +318,39 @@ class SequenceGroupMetadata:
|
||||
|
||||
|
||||
class SequenceOutputs:
|
||||
"""The model output associated with a sequence.
|
||||
|
||||
Args:
|
||||
parent_seq_id: The ID of the parent sequence (for forking in beam
|
||||
search).
|
||||
output_token: The output token ID.
|
||||
logprobs: The logprobs of the output token.
|
||||
(Token id -> logP(x_i+1 | x_0, ..., x_i))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
seq_id: int,
|
||||
parent_seq_id: int,
|
||||
output_token: int,
|
||||
logprobs: Dict[int, float], # Token id -> logP(x_i+1 | x_0, ..., x_i).
|
||||
logprobs: Dict[int, float],
|
||||
) -> None:
|
||||
self.seq_id = seq_id
|
||||
self.parent_seq_id = parent_seq_id
|
||||
self.output_token = output_token
|
||||
self.logprobs = logprobs
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f'SequenceOutputs(seq_id={self.seq_id}, '
|
||||
f'parent_seq_id={self.parent_seq_id}, '
|
||||
f'output_token={self.output_token}), '
|
||||
f'logprobs={self.logprobs}')
|
||||
return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
|
||||
f"output_token={self.output_token}), "
|
||||
f"logprobs={self.logprobs}")
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, SequenceOutputs):
|
||||
return NotImplemented
|
||||
return (self.seq_id == other.seq_id and
|
||||
self.parent_seq_id == other.parent_seq_id and
|
||||
self.output_token == other.output_token and
|
||||
self.logprobs == other.logprobs)
|
||||
return NotImplementedError()
|
||||
return (self.parent_seq_id == other.parent_seq_id
|
||||
and self.output_token == other.output_token
|
||||
and self.logprobs == other.logprobs)
|
||||
|
||||
|
||||
# For each sequence group, we generate a list of SequenceOutputs object,
|
||||
# each of which contains one possible candidate for the next token.
|
||||
SamplerOutput = List[List[SequenceOutputs]]
|
||||
|
||||
vllm/transformers_utils/__init__.py (new empty file)
Some files were not shown because too many files have changed in this diff.