Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-24 17:04:35 +08:00)

Compare commits: submission...v0.2.1 (306 commits)
| Commit SHA1 |
|---|
| 651c614aa4 | |||
| d3a5bd9fb7 | |||
| e8ef4c0820 | |||
| 348897af31 | |||
| 9d9072a069 | |||
| 928de46888 | |||
| 29678cd213 | |||
| d0740dff1b | |||
| de89472897 | |||
| e7c8555d06 | |||
| ec3b5ce9cc | |||
| 6368e777a8 | |||
| 875afe38ab | |||
| ee8217e5be | |||
| 980dd4a2c4 | |||
| 8285736840 | |||
| 91fce82c6f | |||
| ac5cf86aa6 | |||
| 6a6119554c | |||
| b95ee898fe | |||
| 9eed4d1f3e | |||
| 6b5296aa3a | |||
| ee92b58b3a | |||
| 09ff7f106a | |||
| acbed3ef40 | |||
| 66d18a7fb0 | |||
| ba0bfd40e2 | |||
| 84e4e37d14 | |||
| a60b353005 | |||
| ebe4d1db3a | |||
| b5a10eb0ef | |||
| 0967102c6d | |||
| e2fb71ec9f | |||
| f936657eb6 | |||
| 6f88f762bf | |||
| 202351d5bf | |||
| 2e8e49fce3 | |||
| a8e98aee0c | |||
| bb1ba58f06 | |||
| 7bedab5748 | |||
| 20f7cc4cde | |||
| 649aa730c5 | |||
| a19bc5c628 | |||
| 28e616c4e3 | |||
| 30e775281d | |||
| 21877b0d75 | |||
| cf5cb1e33e | |||
| 03ffd0a022 | |||
| a425bd9a9a | |||
| bbbf86565f | |||
| 9f6be8692e | |||
| f187877945 | |||
| 947b794146 | |||
| 8d926e91f1 | |||
| 4ee52bb169 | |||
| 7d7e3b78a3 | |||
| f98b745a81 | |||
| 2d1e86f1b1 | |||
| 1ac4ccf73c | |||
| 2ac4d5e2bf | |||
| 3302f0aef3 | |||
| 6f2dd6c37e | |||
| bc0644574c | |||
| 400b8289f7 | |||
| c1026311b5 | |||
| 2b1c116b5a | |||
| cc796b1358 | |||
| f029ef94d7 | |||
| 95592fa00a | |||
| fbe66e1d0b | |||
| 90979c38f8 | |||
| e21d7687a9 | |||
| ff36139ffc | |||
| e3e79e9e8a | |||
| b9fe4616f9 | |||
| 64ca424e75 | |||
| b5f93d0631 | |||
| a58936966f | |||
| dd54a4b026 | |||
| eda1a7cad3 | |||
| f04908cae7 | |||
| ab019eea75 | |||
| 9841d48a10 | |||
| 3272d7a0b7 | |||
| 0bb1e885a0 | |||
| d6545ad22e | |||
| 90eb3f43ca | |||
| e67b4f2c2a | |||
| d6770d1f23 | |||
| b9cecc2635 | |||
| 898285c9bf | |||
| a62de9ecfd | |||
| 4042d192f5 | |||
| 1117aa1411 | |||
| 080438477f | |||
| 4b5bcf8906 | |||
| 852ef5b4f5 | |||
| db09d4ad83 | |||
| c957c741d9 | |||
| c07ece5ca4 | |||
| 7a9c20c715 | |||
| 005ba458b5 | |||
| 320a622ec4 | |||
| c9927c1a6a | |||
| fbd80ad409 | |||
| 22379d5513 | |||
| 1696725879 | |||
| 002800f081 | |||
| e15932bb60 | |||
| ce741ba3e4 | |||
| bf87484efa | |||
| 8ce9c50d40 | |||
| 32b6816e55 | |||
| c128d69856 | |||
| 55b28b1eee | |||
| e11222333f | |||
| 28873a2799 | |||
| 0080d8329d | |||
| 0d93f15694 | |||
| becd7a56f1 | |||
| 75471386de | |||
| d2b2eed67c | |||
| 4b6f069b6f | |||
| 791d79de32 | |||
| 94d2f59895 | |||
| 75c0ca9d43 | |||
| 2a4ec90854 | |||
| 85ebcda94d | |||
| d64bf1646c | |||
| a41c20435e | |||
| eedac9dba0 | |||
| 14f9c72bfd | |||
| ad5f2fe34c | |||
| 4f8584756d | |||
| 65fc1c3127 | |||
| c393af6cd7 | |||
| 0c04ce3234 | |||
| 73b3de79ea | |||
| d1744376ae | |||
| 805de738f6 | |||
| 1b151ed181 | |||
| e06f504a76 | |||
| 462ae5220a | |||
| 66c54aa9c3 | |||
| 735ecfff61 | |||
| a57d13cc96 | |||
| 79af7e96a0 | |||
| 621980bdc0 | |||
| aa84c92ef6 | |||
| f7389f4763 | |||
| 55fe8a81ec | |||
| e8ddc08ec8 | |||
| 1b0bd0fe8a | |||
| 20044cab7a | |||
| 64f23c2900 | |||
| d4c7755ca8 | |||
| aa39e42c5a | |||
| 953f28cf9a | |||
| c0d00f5be6 | |||
| 58a072be15 | |||
| 82ad323dee | |||
| df5dd3c68e | |||
| 2d867b55fa | |||
| d7a1c6d614 | |||
| 7d5a155e4a | |||
| 1dde34e0f8 | |||
| 6fc2a38b11 | |||
| c487a221ee | |||
| 9925c17940 | |||
| 8c4b2592fb | |||
| cf21a9bd5c | |||
| 16c3e295a8 | |||
| bda41c70dd | |||
| 453bafb96f | |||
| 328d231c17 | |||
| b4b195b360 | |||
| 20b0d88d16 | |||
| 2bdea7ac11 | |||
| 58df2883cb | |||
| 6d7d95a70a | |||
| 96853af5a8 | |||
| dbed69058c | |||
| 7b6ae94059 | |||
| c6dfc3cdbe | |||
| 51be365143 | |||
| c894836108 | |||
| 75beba29b5 | |||
| ddfdf470ae | |||
| b6fbb9a565 | |||
| 2179e4f4c5 | |||
| a945fcc2ae | |||
| be54f8e5c4 | |||
| b396cb4998 | |||
| 1c395b4eaa | |||
| 3d64cf019e | |||
| 98fe8cb542 | |||
| ffa6d2f9f9 | |||
| 404422f42e | |||
| 7717d0838b | |||
| 42e0c1df78 | |||
| e41f06702c | |||
| d6fa1be3a8 | |||
| 0ffded812a | |||
| 0bd2a573a5 | |||
| 49b26e2cec | |||
| dafd924c1f | |||
| 598dc4b79a | |||
| 85de093472 | |||
| f72297562f | |||
| 9d27b09d12 | |||
| 998d9d1509 | |||
| 425040d4c1 | |||
| 4338cc4750 | |||
| bdd6b4c8bc | |||
| 2b7d3aca2e | |||
| 4026a049d3 | |||
| 43710e8d09 | |||
| 526df28fb2 | |||
| 2cf1a333b6 | |||
| 0b7db411b5 | |||
| 471a7a4566 | |||
| 6214dd6ce9 | |||
| 0603379863 | |||
| 665c48963b | |||
| 298695b766 | |||
| 83658c8ace | |||
| 1d24ccb96c | |||
| 14f0b39cda | |||
| 2e0d314384 | |||
| 67d96c29fb | |||
| 033f5c78f5 | |||
| 794e578de0 | |||
| caddfc14c1 | |||
| fc72e39de3 | |||
| b7e62d3454 | |||
| 364536acd1 | |||
| 0b32a987dd | |||
| 570fb2e9cc | |||
| a255885f83 | |||
| 5822ede66e | |||
| 0370afa2e5 | |||
| 7e2a913c64 | |||
| 3f92038b99 | |||
| dcda03b4cb | |||
| bf5f121c02 | |||
| bec7b2dc26 | |||
| 0b98ba15c7 | |||
| e5464ee484 | |||
| bab8f3dd0d | |||
| eedb46bf03 | |||
| 311490a720 | |||
| da5ddcd544 | |||
| 5020e1e80c | |||
| 4298374265 | |||
| e38074b1e6 | |||
| 376725ce74 | |||
| 456941cfe4 | |||
| 1a956e136b | |||
| 8274ca23ac | |||
| 62ec38ea41 | |||
| 0eda2e0953 | |||
| 211318d44a | |||
| 337871c6fd | |||
| 56b7f0efa4 | |||
| d721168449 | |||
| 4a151dd453 | |||
| 057daef778 | |||
| e86717833d | |||
| aedba6d5ec | |||
| a283ec2eec | |||
| 3f942acfe1 | |||
| 19d2899439 | |||
| 655a5e48df | |||
| f746ced08d | |||
| c3442c1f6f | |||
| 7297fa6f7c | |||
| b7955ef17b | |||
| f756799b84 | |||
| 825d8892b5 | |||
| b322fd1607 | |||
| 667ba3995c | |||
| 707ec647bb | |||
| 89988ec8c2 | |||
| 6208d622ca | |||
| 42f1042e1c | |||
| 55f8b0a5de | |||
| 9f88db35da | |||
| ae356774ab | |||
| e331957784 | |||
| 8d66a7b6d7 | |||
| ce26e57fd3 | |||
| 85eb631839 | |||
| add055e151 | |||
| 7c041ab578 | |||
| 8917782af6 | |||
| 7addca5935 | |||
| c84e924287 | |||
| c9d5b6d4a8 | |||
| 189ae23133 | |||
| e548c1488a | |||
| 130d5fd8c7 | |||
| e070829ae8 | |||
| 436e523bf1 | |||
| 27f1410d06 | |||
| 4858f3bb45 | |||
| a96d63c21d |

102  .github/workflows/publish.yml  vendored  Normal file
@@ -0,0 +1,102 @@

# This workflow will upload a Python Package to Release asset
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions

name: Create Release

on:
  push:
    tags:
      - v*

# Needed to create release and upload assets
permissions:
  contents: write

jobs:
  release:
    # Retrieve tag and create release
    name: Create Release
    runs-on: ubuntu-latest
    outputs:
      upload_url: ${{ steps.create_release.outputs.upload_url }}
    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Extract branch info
        shell: bash
        run: |
          echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV

      - name: Create Release
        id: create_release
        uses: "actions/github-script@v6"
        env:
          RELEASE_TAG: ${{ env.release_tag }}
        with:
          github-token: "${{ secrets.GITHUB_TOKEN }}"
          script: |
            const script = require('.github/workflows/scripts/create_release.js')
            await script(github, context, core)

  wheel:
    name: Build Wheel
    runs-on: ${{ matrix.os }}
    needs: release

    strategy:
      fail-fast: false
      matrix:
        os: ['ubuntu-20.04']
        python-version: ['3.8', '3.9', '3.10', '3.11']
        pytorch-version: ['2.0.1']
        cuda-version: ['11.8'] # Github runner can't build anything older than 11.8

    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Set up Linux Env
        if: ${{ runner.os == 'Linux' }}
        run: |
          bash -x .github/workflows/scripts/env.sh

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install CUDA ${{ matrix.cuda-version }}
        run: |
          bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}

      - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }}
        run: |
          bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }}

      - name: Build wheel
        shell: bash
        run: |
          bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
          wheel_name=$(ls dist/*whl | xargs -n 1 basename)
          asset_name=${wheel_name//"linux"/"manylinux1"}
          echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
          echo "asset_name=${asset_name}" >> $GITHUB_ENV

      - name: Upload Release Asset
        uses: actions/upload-release-asset@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          upload_url: ${{ needs.release.outputs.upload_url }}
          asset_path: ./dist/${{ env.wheel_name }}
          asset_name: ${{ env.asset_name }}
          asset_content_type: application/*

      # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
      # - name: Publish package
      #   uses: pypa/gh-action-pypi-publish@release/v1.8
      #   with:
      #     repository-url: https://test.pypi.org/legacy/
      #     password: ${{ secrets.PYPI_API_TOKEN }}
      #     skip-existing: true

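For reference, a minimal sketch of how this workflow is exercised: it only runs on pushed tags matching the `v*` pattern, so cutting a release amounts to pushing such a tag. The tag name below is illustrative, and this assumes push access to the repository.

```bash
# Create and push an annotated tag matching the v* pattern; this fires the
# "Create Release" workflow, which creates the GitHub release and then builds
# wheels for Python 3.8-3.11 against CUDA 11.8 and attaches them as assets.
git tag -a v0.2.1 -m "vLLM v0.2.1"
git push origin v0.2.1
```
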
31  .github/workflows/pylint.yml  vendored  Normal file
@@ -0,0 +1,31 @@

name: pylint

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  pylint:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install pylint==2.8.2
    - name: Analysing the code with pylint
      run: |
        pylint vllm tests

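The same check can be reproduced locally before pushing. This simply mirrors the steps of pylint.yml (the pinned pylint 2.8.2 run over `vllm` and `tests`) and assumes it is executed from the repository root.

```bash
# Reproduce the CI lint job locally with the same pinned version.
python -m pip install --upgrade pip
pip install pylint==2.8.2
pylint vllm tests
```
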
15  .github/workflows/scripts/build.sh  vendored  Normal file
@@ -0,0 +1,15 @@

#!/bin/bash

python_executable=python$1
cuda_home=/usr/local/cuda-$2

# Update paths
PATH=${cuda_home}/bin:$PATH
LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH

# Install requirements
$python_executable -m pip install wheel packaging
$python_executable -m pip install -r requirements.txt

# Build
$python_executable setup.py bdist_wheel --dist-dir=dist

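As a usage note, publish.yml calls this script with a Python version and a CUDA version as positional arguments. A local invocation would look roughly like the sketch below; the versions are illustrative and assume a matching `python3.10` binary and a `/usr/local/cuda-11.8` installation are present.

```bash
# $1 -> Python version suffix, $2 -> CUDA version (maps to /usr/local/cuda-$2).
# Produces a wheel under dist/.
bash -x .github/workflows/scripts/build.sh 3.10 11.8
ls dist/*.whl
```
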
20  .github/workflows/scripts/create_release.js  vendored  Normal file
@@ -0,0 +1,20 @@

// Uses Github's API to create the release and wait for result.
// We use a JS script since github CLI doesn't provide a way to wait for the release's creation and returns immediately.

module.exports = async (github, context, core) => {
    try {
        const response = await github.rest.repos.createRelease({
            draft: false,
            generate_release_notes: true,
            name: process.env.RELEASE_TAG,
            owner: context.repo.owner,
            prerelease: false,
            repo: context.repo.repo,
            tag_name: process.env.RELEASE_TAG,
        });

        core.setOutput('upload_url', response.data.upload_url);
    } catch (error) {
        core.setFailed(error.message);
    }
}

18  .github/workflows/scripts/cuda-install.sh  vendored  Normal file
@@ -0,0 +1,18 @@

#!/bin/bash

# Replace '.' with '-' ex: 11.8 -> 11-8
cuda_version=$(echo $1 | tr "." "-")
# Removes '-' and '.' ex: ubuntu-20.04 -> ubuntu2004
OS=$(echo $2 | tr -d ".\-")

# Installs CUDA
wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
rm cuda-keyring_1.1-1_all.deb
sudo apt -qq update
sudo apt -y install cuda-${cuda_version} cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version}
sudo apt clean

# Test nvcc
PATH=/usr/local/cuda-$1/bin:${PATH}
nvcc --version

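For reference, publish.yml passes the CUDA version and the runner OS label as the two positional arguments. A local equivalent, using the same values as the CI matrix, would be the following; it requires sudo and an Ubuntu host matching the OS label.

```bash
# $1 -> CUDA version ("11.8" becomes the cuda-11-8 apt packages),
# $2 -> OS label ("ubuntu-20.04" becomes the ubuntu2004 repo path).
bash -x .github/workflows/scripts/cuda-install.sh 11.8 ubuntu-20.04
```
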
56  .github/workflows/scripts/env.sh  vendored  Normal file
@@ -0,0 +1,56 @@

#!/bin/bash

# This file installs common linux environment tools

export LANG C.UTF-8

# python_version=$1

sudo apt-get update && \
sudo apt-get install -y --no-install-recommends \
    software-properties-common \

sudo apt-get install -y --no-install-recommends \
    build-essential \
    apt-utils \
    ca-certificates \
    wget \
    git \
    vim \
    libssl-dev \
    curl \
    unzip \
    unrar \
    cmake \
    net-tools \
    sudo \
    autotools-dev \
    rsync \
    jq \
    openssh-server \
    tmux \
    screen \
    htop \
    pdsh \
    openssh-client \
    lshw \
    dmidecode \
    util-linux \
    automake \
    autoconf \
    libtool \
    net-tools \
    pciutils \
    libpci-dev \
    libaio-dev \
    libcap2 \
    libtinfo5 \
    fakeroot \
    devscripts \
    debhelper \
    nfs-common

# Remove github bloat files to free up disk space
sudo rm -rf "/usr/local/share/boost"
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
sudo rm -rf "/usr/share/dotnet"

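The "Set up Linux Env" step in publish.yml runs this script with no arguments (the `python_version=$1` line is commented out). A local run is just the sketch below; it assumes an Ubuntu/Debian host with passwordless sudo and is only worth running on a disposable machine, since it installs a large toolchain and deletes runner-specific directories.

```bash
# Install the build prerequisites and free up disk space, as CI does.
bash -x .github/workflows/scripts/env.sh
```
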
15  .github/workflows/scripts/pytorch-install.sh  vendored  Normal file
@@ -0,0 +1,15 @@

#!/bin/bash

python_executable=python$1
pytorch_version=$2
cuda_version=$3

# Install torch
$python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya
$python_executable -m pip install torch==${pytorch_version}+cu${cuda_version//./} --extra-index-url https://download.pytorch.org/whl/cu${cuda_version//./}

# Print version information
$python_executable --version
$python_executable -c "import torch; print('PyTorch:', torch.__version__)"
$python_executable -c "import torch; print('CUDA:', torch.version.cuda)"
$python_executable -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"

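publish.yml passes three positional arguments: Python version, PyTorch version, and CUDA version. A local invocation mirroring the CI matrix (values illustrative) would be:

```bash
# $1 -> Python version, $2 -> torch version, $3 -> CUDA version.
# Installs torch 2.0.1 built for cu118 from the PyTorch wheel index.
bash -x .github/workflows/scripts/pytorch-install.sh 3.10 2.0.1 11.8
```
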
31  .github/workflows/yapf.yml  vendored  Normal file
@@ -0,0 +1,31 @@

name: yapf

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
jobs:
  yapf:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install yapf==0.32.0
        pip install toml==0.10.2
    - name: Running yapf
      run: |
        yapf --diff --recursive vllm tests

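As with the pylint job, the formatting check can be reproduced locally with the same pinned versions. `--diff` only reports differences (what CI runs); `--in-place` is a standard yapf flag, not used by CI, that rewrites the files instead.

```bash
# Reproduce the CI formatting check locally.
pip install yapf==0.32.0 toml==0.10.2
yapf --diff --recursive vllm tests        # report style differences (what CI runs)
yapf --in-place --recursive vllm tests    # optionally fix them locally
```
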
185  .gitignore  vendored
@@ -1,10 +1,179 @@
Removed (old ignore list): **/*.pyc, **/__pycache__/, *.egg-info/, *.eggs/, *.so, build/, *.pkl, *.png, **/log.txt
Added:

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# VSCode
.vscode/

# DS Store
.DS_Store

# Results
*.csv

# Python pickle files
*.pkl

# Sphinx documentation
_build/

# vim swap files
*.swo
*.swp

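A quick way to confirm which pattern in the new .gitignore covers a given path is `git check-ignore`; the paths below are illustrative.

```bash
# -v prints the .gitignore line that matches each path.
git check-ignore -v build/ dist/example.whl __pycache__/foo.pyc results.csv
```
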
434  .pylintrc  Normal file
@@ -0,0 +1,434 @@

# This Pylint rcfile contains a best-effort configuration to uphold the
# best-practices and style described in the Google Python style guide:
#   https://google.github.io/styleguide/pyguide.html
#
# Its canonical open-source location is:
#   https://google.github.io/styleguide/pylintrc

[MASTER]

# Files or directories to be skipped. They should be base names, not paths.
ignore=docs

# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
ignore-patterns=

# Pickle collected data for later comparisons.
persistent=no

# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=

# Use multiple processes to speed up Pylint.
jobs=4

# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no


[MESSAGES CONTROL]

# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=

# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=abstract-method,
        apply-builtin,
        arguments-differ,
        attribute-defined-outside-init,
        backtick,
        bad-option-value,
        basestring-builtin,
        buffer-builtin,
        c-extension-no-member,
        consider-using-enumerate,
        cmp-builtin,
        cmp-method,
        coerce-builtin,
        coerce-method,
        delslice-method,
        div-method,
        duplicate-code,
        eq-without-hash,
        execfile-builtin,
        file-builtin,
        filter-builtin-not-iterating,
        fixme,
        getslice-method,
        global-statement,
        hex-method,
        idiv-method,
        implicit-str-concat-in-sequence,
        import-error,
        import-self,
        import-star-module-level,
        inconsistent-return-statements,
        input-builtin,
        intern-builtin,
        invalid-str-codec,
        locally-disabled,
        logging-fstring-interpolation, # added by vLLM
        logging-not-lazy, # added by vLLM
        long-builtin,
        long-suffix,
        map-builtin-not-iterating,
        misplaced-comparison-constant,
        missing-class-docstring, # TODO (vLLM): enable
        missing-function-docstring,
        missing-module-docstring, # TODO (vLLM): enable
        metaclass-assignment,
        next-method-called,
        next-method-defined,
        no-absolute-import,
        no-else-break,
        no-else-continue,
        no-else-raise,
        no-else-return,
        no-init, # added
        no-member,
        no-name-in-module,
        no-self-use,
        nonzero-method,
        oct-method,
        old-division,
        old-ne-operator,
        old-octal-literal,
        old-raise-syntax,
        parameter-unpacking,
        print-statement,
        raising-string,
        range-builtin-not-iterating,
        raw_input-builtin,
        rdiv-method,
        reduce-builtin,
        relative-import,
        reload-builtin,
        round-builtin,
        setslice-method,
        signature-differs,
        standarderror-builtin,
        suppressed-message,
        sys-max-int,
        too-few-public-methods,
        too-many-ancestors,
        too-many-arguments,
        too-many-boolean-expressions,
        too-many-branches,
        too-many-instance-attributes,
        too-many-locals,
        too-many-nested-blocks,
        too-many-public-methods,
        too-many-return-statements,
        too-many-statements,
        trailing-newlines,
        unichr-builtin,
        unicode-builtin,
        unnecessary-pass,
        unpacking-in-except,
        unspecified-encoding,
        useless-else-on-loop,
        useless-object-inheritance,
        useless-suppression,
        using-cmp-argument,
        wrong-import-order,
        xrange-builtin,
        zip-builtin-not-iterating,


[REPORTS]

# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=text

# Tells whether to display a full report or only the messages
reports=no

# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)

# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=


[BASIC]

# Good variable names which should always be accepted, separated by a comma
good-names=main,_

# Bad variable names which should always be refused, separated by a comma
bad-names=

# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=

# Include a hint for the correct naming format with invalid-name
include-naming-hint=no

# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl

# Regular expression matching correct function names
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$

# Regular expression matching correct variable names
variable-rgx=^[a-z][a-z0-9_]*$

# Regular expression matching correct constant names
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$

# Regular expression matching correct attribute names
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$

# Regular expression matching correct argument names
argument-rgx=^[a-z][a-z0-9_]*$

# Regular expression matching correct class attribute names
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$

# Regular expression matching correct inline iteration names
inlinevar-rgx=^[a-z][a-z0-9_]*$

# Regular expression matching correct class names
class-rgx=^_?[A-Z][a-zA-Z0-9]*$

# Regular expression matching correct module names
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$

# Regular expression matching correct method names
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$

# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$

# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=10


[TYPECHECK]

# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager

# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes

# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=

# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local

# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=


[FORMAT]

# Maximum number of characters on a single line.
max-line-length=80

# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
# lines made too long by directives to pytype.

# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=(?x)(
  ^\s*(\#\ )?<?https?://\S+>?$|
  ^\s*(from\s+\S+\s+)?import\s+.+$)

# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes

# Maximum number of lines in a module
max-module-lines=99999

# String used as indentation unit. The internal Google style guide mandates 2
# spaces. Google's externaly-published style guide says 4, consistent with
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
# projects (like TensorFlow).
indent-string='  '

# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4

# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=


[MISCELLANEOUS]

# List of note tags to take in consideration, separated by a comma.
notes=TODO


[STRING]

# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=yes


[VARIABLES]

# Tells whether we should check for unused import in __init__ files.
init-import=no

# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)

# List of additional names supposed to be defined in builtins. Remember that
# you should avoid to define new builtins when possible.
additional-builtins=

# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb

# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools


[LOGGING]

# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging,absl.logging,tensorflow.io.logging


[SIMILARITIES]

# Minimum lines number of a similarity.
min-similarity-lines=4

# Ignore comments when computing similarities.
ignore-comments=yes

# Ignore docstrings when computing similarities.
ignore-docstrings=yes

# Ignore imports when computing similarities.
ignore-imports=no


[SPELLING]

# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package.
spelling-dict=

# List of comma separated words that should not be checked.
spelling-ignore-words=

# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=

# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no


[IMPORTS]

# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,
                   TERMIOS,
                   Bastion,
                   rexec,
                   sets

# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=

# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=

# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=

# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=

# Force import order to recognize a module as part of a third party library.
known-third-party=enchant, absl

# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no


[CLASSES]

# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
                      __new__,
                      setUp

# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
                  _fields,
                  _replace,
                  _source,
                  _make

# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls,
                            class_

# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs


[EXCEPTIONS]

# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=StandardError,
                       Exception,
                       BaseException

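Pylint picks up a `.pylintrc` in the working directory automatically, so the CI command `pylint vllm tests` already uses this configuration when run from the repository root; pointing at the file explicitly also works.

```bash
# Explicitly use the repository configuration.
pylint --rcfile=.pylintrc vllm tests
```
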
21  .readthedocs.yaml  Normal file
@@ -0,0 +1,21 @@

# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

version: 2

build:
  os: ubuntu-22.04
  tools:
    python: "3.8"

sphinx:
  configuration: docs/source/conf.py

# If using Sphinx, optionally build your docs in additional formats such as PDF
formats:
  - pdf

# Optionally declare the Python requirements required to build your docs
python:
  install:
    - requirements: docs/requirements-docs.txt

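Read the Docs drives the build from this file, but the same Sphinx sources can be built locally. This is only a sketch; the `docs/build/html` output directory is an arbitrary choice rather than anything mandated by the config.

```bash
# Install the documented requirements and build the HTML docs locally.
pip install -r docs/requirements-docs.txt
sphinx-build -b html docs/source docs/build/html
```
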
77  CONTRIBUTING.md  Normal file
@@ -0,0 +1,77 @@

# Contributing to vLLM

Thank you for your interest in contributing to vLLM!
Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large.
There are several ways you can contribute to the project:

- Identify and report any issues or bugs.
- Request or add a new model.
- Suggest or implement new features.

However, remember that contributions aren't just about code.
We believe in the power of community support; thus, answering queries, assisting others, and enhancing the documentation are highly regarded and beneficial contributions.

Finally, one of the most impactful ways to support us is by raising awareness about vLLM.
Talk about it in your blog posts, highlighting how it's driving your incredible projects.
Express your support on Twitter if vLLM aids you, or simply offer your appreciation by starring our repository.


## Setup for development

### Build from source

```bash
pip install -r requirements.txt
pip install -e .  # This may take several minutes.
```

### Testing

```bash
pip install -r requirements-dev.txt

# Static type checking
mypy
# Unit tests
pytest tests/
```
**Note:** Currently, the repository does not pass the mypy tests.


## Contributing Guidelines

### Issue Reporting

If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it.
If not, please file a new issue, providing as much relevant information as possible.

### Coding Style Guide

In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).

We include a formatting script [`format.sh`](./format.sh) to format the code.

### Pull Requests

When submitting a pull request:

1. Make sure your code has been rebased on top of the latest commit on the main branch.
2. Ensure code is properly formatted by running [`format.sh`](./format.sh).
3. Include a detailed description of the changes in the pull request.
Explain why you made the changes you did.
If your pull request fixes an open issue, please include a reference to it in the description.

### Code Reviews

All submissions, including submissions by project members, require a code review.
To make the review process as smooth as possible, please:

1. Keep your changes as concise as possible.
If your pull request involves multiple unrelated changes, consider splitting it into separate pull requests.
2. Respond to all comments within a reasonable time frame.
If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

### Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM.
Your contributions make vLLM a great tool for everyone!

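To tie the pull-request checklist together, a typical local pre-submission pass might look like the sketch below. The `upstream` remote name is an assumption; `format.sh` is the repository's own formatting script referenced in the guidelines above.

```bash
# Rebase onto the latest main (assumes a remote named "upstream"
# pointing at https://github.com/vllm-project/vllm.git).
git fetch upstream
git rebase upstream/main

# Format and run the local checks before opening the PR.
bash format.sh
pytest tests/
```
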
201  LICENSE  Normal file
@@ -0,0 +1,201 @@

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

MANIFEST.in (new file)
@@ -0,0 +1,4 @@
include LICENSE
include requirements.txt

recursive-include csrc *
README.md
@@ -1,72 +1,92 @@
-# CacheFlow
+<p align="center">
+<picture>
+<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/logos/vllm-logo-text-dark.png">
+<img alt="vLLM" src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/logos/vllm-logo-text-light.png" width=55%>
+</picture>
+</p>

-## Installation
+<h3 align="center">
+Easy, fast, and cheap LLM serving for everyone
+</h3>

+<p align="center">
+| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |

+</p>

+---

+*Latest News* 🔥
+- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
+- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
+- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
+- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
+- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
+- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
+- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).

+---

+vLLM is a fast and easy-to-use library for LLM inference and serving.

+vLLM is fast with:

+- State-of-the-art serving throughput
+- Efficient management of attention key and value memory with **PagedAttention**
+- Continuous batching of incoming requests
+- Optimized CUDA kernels

+vLLM is flexible and easy to use with:

+- Seamless integration with popular Hugging Face models
+- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
+- Tensor parallelism support for distributed inference
+- Streaming outputs
+- OpenAI-compatible API server

+vLLM seamlessly supports many Hugging Face models, including the following architectures:

+- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
+- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
+- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
+- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
+- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
+- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
+- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
+- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
+- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
+- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
+- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
+- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
+- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
+- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)

+Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):

 ```bash
-pip install psutil numpy ray torch
+pip install vllm
-pip install git+https://github.com/huggingface/transformers # Required for LLaMA.
-pip install sentencepiece # Required for LlamaTokenizer.
-pip install ninja # To parallelize the compilation of flash-attn.
-pip install flash-attn # This may take up to 10 mins.
-pip install -e .
 ```

-## Test simple server
+## Getting Started

-```bash
+Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started.
-ray start --head
+- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)
-python simple_server.py
+- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
+- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)

+## Contributing

+We welcome and value any contributions and collaborations.
+Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.

+## Citation

+If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
+```bibtex
+@inproceedings{kwon2023efficient,
+title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
+author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
+booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
+year={2023}
+}
 ```

-The detailed arguments for `simple_server.py` can be found by:
-```bash
-python simple_server.py --help
-```

-## FastAPI server

-Install the following additional dependencies:
-```bash
-pip install fastapi uvicorn
-```

-To start the server:
-```bash
-ray start --head
-python -m cacheflow.http_frontend.fastapi_frontend
-```

-To test the server:
-```bash
-python -m cacheflow.http_frontend.test_cli_client
-```

-## Gradio web server

-Install the following additional dependencies:
-```bash
-pip install gradio
-```

-Start the server:
-```bash
-python -m cacheflow.http_frontend.fastapi_frontend
-# At another terminal
-python -m cacheflow.http_frontend.gradio_webserver
-```

-## Load LLaMA weights

-Since LLaMA weight is not fully public, we cannot directly download the LLaMA weights from huggingface. Therefore, you need to follow the following process to load the LLaMA weights.

-1. Converting LLaMA weights to huggingface format with [this script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py).
-```bash
-python src/transformers/models/llama/convert_llama_weights_to_hf.py \
-    --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path/llama-7b
-```
-Please make sure that `llama` is included in the output directory name.
-2. For all the commands above, specify the model with `--model /output/path/llama-7b` to load the model. For example:
-```bash
-python simple_server.py --model /output/path/llama-7b
-python -m cacheflow.http_frontend.fastapi_frontend --model /output/path/llama-7b
-```
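The Quickstart linked above boils down to a few lines of Python. The following is a minimal offline-generation sketch using the `LLM` and `SamplingParams` classes that the benchmark scripts further down also use; the model name and sampling values are only examples, and the exact arguments should be checked against the documentation.

```python
# Minimal offline-generation sketch with vLLM's Python API.
# "facebook/opt-125m" is only an example model.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=128)

prompts = ["Hello, my name is", "The capital of France is"]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    # Each output carries the prompt and one or more generated completions.
    print(output.prompt, output.outputs[0].text)
```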
(deleted file)
@@ -1,165 +0,0 @@
import functools
|
|
||||||
import random
|
|
||||||
import time
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from flash_attn.flash_attn_interface import _flash_attn_forward
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from cacheflow import attention_ops
|
|
||||||
|
|
||||||
|
|
||||||
def benchmark(name, f, num_warmup = 10, num_iters = 100):
|
|
||||||
for _ in range(num_warmup):
|
|
||||||
f()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
for _ in range(num_iters):
|
|
||||||
f()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
end = time.time()
|
|
||||||
print(f'{name}: {(end - start) / num_iters * 1000:.3f} ms')
|
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def benchmark_multi_query_cached_kv_attention(
|
|
||||||
query_lens: List[int],
|
|
||||||
context_lens: List[int],
|
|
||||||
num_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
block_size: int,
|
|
||||||
num_blocks: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
) -> None:
|
|
||||||
print(f'query_lens: {query_lens}, context_lens: {context_lens}, '
|
|
||||||
f'num_heads: {num_heads}, head_size: {head_size}, block_size: '
|
|
||||||
f'{block_size}, num_blocks: {num_blocks}, dtype: {dtype}')
|
|
||||||
# Create query tensor.
|
|
||||||
num_queries = len(query_lens)
|
|
||||||
cu_query_lens = [0]
|
|
||||||
for query_len in query_lens:
|
|
||||||
cu_query_lens.append(cu_query_lens[-1] + query_len)
|
|
||||||
num_total_tokens = cu_query_lens[-1]
|
|
||||||
qkv = torch.randn(
|
|
||||||
num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
|
||||||
query, _, _ = qkv.unbind(dim=1)
|
|
||||||
|
|
||||||
# Create key and value cache.
|
|
||||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
|
||||||
key_block_shape = (num_heads, head_size // x, block_size, x)
|
|
||||||
key_cache = torch.randn(
|
|
||||||
size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
|
|
||||||
value_block_shape = (num_heads, head_size, block_size)
|
|
||||||
value_cache = torch.randn(
|
|
||||||
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
|
|
||||||
|
|
||||||
# Create block tables.
|
|
||||||
max_context_len = max(context_lens)
|
|
||||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
|
||||||
block_tables = []
|
|
||||||
for _ in range(num_queries):
|
|
||||||
block_table = [
|
|
||||||
random.randint(0, num_blocks - 1)
|
|
||||||
for _ in range(max_num_blocks_per_seq)
|
|
||||||
]
|
|
||||||
block_tables.append(block_table)
|
|
||||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
|
|
||||||
|
|
||||||
# Create input and output data structures.
|
|
||||||
cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda')
|
|
||||||
context_len_tensor = torch.tensor(context_lens, dtype=torch.int, device='cuda')
|
|
||||||
scale = float(1.0 / (head_size ** 0.5))
|
|
||||||
output = torch.empty(
|
|
||||||
num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
|
||||||
|
|
||||||
# Run our implementation.
|
|
||||||
def run_ours():
|
|
||||||
attention_ops.multi_query_cached_kv_attention(
|
|
||||||
cu_query_lens,
|
|
||||||
output,
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
scale,
|
|
||||||
block_tables,
|
|
||||||
context_len_tensor,
|
|
||||||
block_size,
|
|
||||||
max_context_len,
|
|
||||||
)
|
|
||||||
benchmark('Ours', run_ours)
|
|
||||||
|
|
||||||
# Upper bound: Flash attention.
|
|
||||||
# Because Flash attention cannot read our own cache,
|
|
||||||
# we make key and value tensors contiguous.
|
|
||||||
num_kv_tokens = sum(context_lens)
|
|
||||||
cu_context_lens = [0]
|
|
||||||
for context_len in context_lens:
|
|
||||||
cu_context_lens.append(cu_context_lens[-1] + context_len)
|
|
||||||
cu_context_lens = torch.tensor(cu_context_lens, dtype=torch.int, device='cuda')
|
|
||||||
qkv = torch.randn(
|
|
||||||
num_kv_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
|
||||||
_, key, value = qkv.unbind(dim=1)
|
|
||||||
ref_output = torch.empty_like(output)
|
|
||||||
|
|
||||||
# Run Flash attention.
|
|
||||||
def run_flash_attn():
|
|
||||||
_flash_attn_forward(
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
ref_output,
|
|
||||||
cu_query_lens,
|
|
||||||
cu_context_lens,
|
|
||||||
max(query_lens),
|
|
||||||
max_context_len,
|
|
||||||
dropout_p=0.0,
|
|
||||||
softmax_scale=scale,
|
|
||||||
causal=True,
|
|
||||||
return_softmax=False,
|
|
||||||
)
|
|
||||||
benchmark('Flash attention', run_flash_attn)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
BLOCK_SIZE = 8
|
|
||||||
NUM_BLOCKS = 1024
|
|
||||||
DTYPE = torch.half
|
|
||||||
|
|
||||||
# LLaMA-13B and OPT-13B
|
|
||||||
NUM_HEADS = 40
|
|
||||||
HEAD_SIZE = 128
|
|
||||||
|
|
||||||
run_benchmark = functools.partial(
|
|
||||||
benchmark_multi_query_cached_kv_attention,
|
|
||||||
num_heads=NUM_HEADS,
|
|
||||||
head_size=HEAD_SIZE,
|
|
||||||
block_size=BLOCK_SIZE,
|
|
||||||
num_blocks=NUM_BLOCKS,
|
|
||||||
dtype=DTYPE,
|
|
||||||
)
|
|
||||||
|
|
||||||
run_benchmark(
|
|
||||||
query_lens=[64] * 1,
|
|
||||||
context_lens=[64] * 1,
|
|
||||||
)
|
|
||||||
run_benchmark(
|
|
||||||
query_lens=[128] * 1,
|
|
||||||
context_lens=[128] * 1,
|
|
||||||
)
|
|
||||||
run_benchmark(
|
|
||||||
query_lens=[64] * 8,
|
|
||||||
context_lens=[64] * 8,
|
|
||||||
)
|
|
||||||
run_benchmark(
|
|
||||||
query_lens=[128] * 8,
|
|
||||||
context_lens=[128] * 8,
|
|
||||||
)
|
|
||||||
run_benchmark(
|
|
||||||
query_lens=[64, 32, 16],
|
|
||||||
context_lens=[128, 256, 64],
|
|
||||||
)
|
|
||||||
run_benchmark(
|
|
||||||
query_lens=[1024],
|
|
||||||
context_lens=[1024],
|
|
||||||
)
|
|
(deleted file)
@@ -1,81 +0,0 @@
import functools
|
|
||||||
import random
|
|
||||||
import time
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from cacheflow import cache_ops
|
|
||||||
|
|
||||||
|
|
||||||
def benchmark(name, f, size: int, num_warmup = 10, num_iters = 100):
|
|
||||||
for _ in range(num_warmup):
|
|
||||||
f()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
for _ in range(num_iters):
|
|
||||||
f()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
end = time.time()
|
|
||||||
avg_time = (end - start) / num_iters
|
|
||||||
print(f'[Latency] {name}: {avg_time * 1000:.3f} ms')
|
|
||||||
print(f'[Throughput] {name}: {size / avg_time / 2 ** 30:.3f} GB/s')
|
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def test_gather_cached_kv(
|
|
||||||
num_tokens: int,
|
|
||||||
num_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
block_size: int,
|
|
||||||
num_blocks: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
) -> None:
|
|
||||||
print(f'num_tokens: {num_tokens}, num_heads: {num_heads}, '
|
|
||||||
f'head_size: {head_size}, block_size: {block_size}, '
|
|
||||||
f'num_blocks: {num_blocks}, dtype: {dtype}')
|
|
||||||
|
|
||||||
num_slots = block_size * num_blocks
|
|
||||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
|
||||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
|
||||||
|
|
||||||
qkv = torch.randn(
|
|
||||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
|
||||||
_, key, value = qkv.unbind(dim=1)
|
|
||||||
|
|
||||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
|
||||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
|
||||||
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
|
|
||||||
|
|
||||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
|
||||||
value_cache = torch.randn(
|
|
||||||
size=value_cache_shape, dtype=dtype, device='cuda')
|
|
||||||
|
|
||||||
# Run the gather_cached_kv op.
|
|
||||||
def run():
|
|
||||||
cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)
|
|
||||||
|
|
||||||
benchmark('gather_cached_kv', run,
|
|
||||||
size=num_tokens * num_heads * head_size * 2 * qkv.element_size())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
BLOCK_SIZE = 8
|
|
||||||
NUM_BLOCKS = 1024
|
|
||||||
DTYPE = torch.half
|
|
||||||
|
|
||||||
# LLaMA-13B and OPT-13B
|
|
||||||
NUM_HEADS = 40
|
|
||||||
HEAD_SIZE = 128
|
|
||||||
|
|
||||||
run_benchmark = functools.partial(
|
|
||||||
test_gather_cached_kv,
|
|
||||||
num_heads=NUM_HEADS,
|
|
||||||
head_size=HEAD_SIZE,
|
|
||||||
block_size=BLOCK_SIZE,
|
|
||||||
num_blocks=NUM_BLOCKS,
|
|
||||||
dtype=DTYPE,
|
|
||||||
)
|
|
||||||
|
|
||||||
for i in range(6, 12):
|
|
||||||
run_benchmark(num_tokens=2 ** i)
|
|
(deleted file)
@@ -1,105 +0,0 @@
import argparse
|
|
||||||
import time
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from cacheflow.master.simple_frontend import SimpleFrontend
|
|
||||||
from cacheflow.master.server import (Server, add_server_arguments,
|
|
||||||
initialize_ray_cluster)
|
|
||||||
from cacheflow.sampling_params import SamplingParams
|
|
||||||
from cacheflow.utils import get_gpu_memory, get_cpu_memory
|
|
||||||
|
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
|
||||||
# TODO(zhuohan): Support pipeline parallelism.
|
|
||||||
assert args.pipeline_parallel_size == 1, (
|
|
||||||
'Pipeline parallelism is not supported yet.')
|
|
||||||
|
|
||||||
(num_nodes, num_devices_per_node, distributed_init_method,
|
|
||||||
all_stage_devices) = (
|
|
||||||
initialize_ray_cluster(
|
|
||||||
address='local',
|
|
||||||
pipeline_parallel_size=args.pipeline_parallel_size,
|
|
||||||
tensor_parallel_size=args.tensor_parallel_size))
|
|
||||||
|
|
||||||
# Create a server.
|
|
||||||
server = Server(
|
|
||||||
model=args.model,
|
|
||||||
model_path=args.model_path,
|
|
||||||
use_dummy_weights=args.use_dummy_weights,
|
|
||||||
pipeline_parallel_size=args.pipeline_parallel_size,
|
|
||||||
tensor_parallel_size=args.tensor_parallel_size,
|
|
||||||
block_size=args.block_size,
|
|
||||||
dtype=args.dtype,
|
|
||||||
seed=args.seed,
|
|
||||||
swap_space=args.swap_space,
|
|
||||||
max_num_batched_tokens=args.max_num_batched_tokens,
|
|
||||||
max_num_sequences=args.max_num_sequences,
|
|
||||||
num_nodes=num_nodes,
|
|
||||||
num_devices_per_node=num_devices_per_node,
|
|
||||||
distributed_init_method=distributed_init_method,
|
|
||||||
all_stage_devices=all_stage_devices,
|
|
||||||
gpu_memory=get_gpu_memory(),
|
|
||||||
cpu_memory=get_cpu_memory(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a frontend.
|
|
||||||
frontend = SimpleFrontend(
|
|
||||||
model_name=args.model,
|
|
||||||
block_size=args.block_size,
|
|
||||||
)
|
|
||||||
sampling_params_dict = {
|
|
||||||
'n': args.n,
|
|
||||||
'temperature': 0.0 if args.use_beam_search else 1.0,
|
|
||||||
'top_p': 1.0,
|
|
||||||
'use_beam_search': args.use_beam_search,
|
|
||||||
'stop_token_ids': set(),
|
|
||||||
'max_num_steps': args.output_len,
|
|
||||||
}
|
|
||||||
sampling_params = SamplingParams.from_dict(sampling_params_dict)
|
|
||||||
print(sampling_params)
|
|
||||||
input_token_ids = [0] * args.input_len
|
|
||||||
|
|
||||||
def profile_step(profile=False):
|
|
||||||
if profile:
|
|
||||||
torch.cuda.cudart().cudaProfilerStart()
|
|
||||||
for _ in range(args.batch_size):
|
|
||||||
frontend._add_query(input_token_ids, sampling_params)
|
|
||||||
server.add_sequence_groups(frontend.get_inputs())
|
|
||||||
start_time = time.time()
|
|
||||||
while True:
|
|
||||||
server.step()
|
|
||||||
if not server.has_unfinished_requests():
|
|
||||||
break
|
|
||||||
end_time = time.time()
|
|
||||||
latency = end_time - start_time
|
|
||||||
if profile:
|
|
||||||
torch.cuda.cudart().cudaProfilerStop()
|
|
||||||
return latency
|
|
||||||
|
|
||||||
print("Warm up step")
|
|
||||||
profile_step()
|
|
||||||
|
|
||||||
# Benchmark.
|
|
||||||
latencies = []
|
|
||||||
for _ in tqdm(range(3), desc="Profile step"):
|
|
||||||
latencies.append(profile_step())
|
|
||||||
print(f'Avg latency: {np.mean(latencies)} seconds')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parser = argparse.ArgumentParser(description='CacheFlow simple server.')
|
|
||||||
parser = add_server_arguments(parser)
|
|
||||||
parser.add_argument('--input-len', type=int, default=32)
|
|
||||||
parser.add_argument('--output-len', type=int, default=128)
|
|
||||||
parser.add_argument('--batch-size', type=int, default=8)
|
|
||||||
parser.add_argument('--n', type=int, default=1)
|
|
||||||
parser.add_argument('--use-beam-search', action='store_true')
|
|
||||||
args = parser.parse_args()
|
|
||||||
args.max_num_batched_tokens = max(
|
|
||||||
args.max_num_batched_tokens, args.batch_size * args.input_len)
|
|
||||||
print(args)
|
|
||||||
main(args)
|
|
(deleted file)
@@ -1,290 +0,0 @@
import argparse
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import pickle
|
|
||||||
import time
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
from transformers import AutoConfig
|
|
||||||
|
|
||||||
from benchmark.trace import generate_text_completion_requests
|
|
||||||
from cacheflow.master.simple_frontend import SimpleFrontend
|
|
||||||
from cacheflow.master.server import (Server, add_server_arguments,
|
|
||||||
initialize_ray_cluster)
|
|
||||||
from cacheflow.sampling_params import SamplingParams
|
|
||||||
from cacheflow.utils import get_gpu_memory, get_cpu_memory
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
|
||||||
assert args.pipeline_parallel_size == 1, (
|
|
||||||
'Pipeline parallelism is not supported yet.')
|
|
||||||
|
|
||||||
(num_nodes, num_devices_per_node, distributed_init_method,
|
|
||||||
all_stage_devices) = (
|
|
||||||
initialize_ray_cluster(
|
|
||||||
address='local',
|
|
||||||
pipeline_parallel_size=args.pipeline_parallel_size,
|
|
||||||
tensor_parallel_size=args.tensor_parallel_size))
|
|
||||||
|
|
||||||
# Create a server.
|
|
||||||
server = Server(
|
|
||||||
model=args.model,
|
|
||||||
model_path=args.model_path,
|
|
||||||
use_dummy_weights=args.use_dummy_weights,
|
|
||||||
pipeline_parallel_size=args.pipeline_parallel_size,
|
|
||||||
tensor_parallel_size=args.tensor_parallel_size,
|
|
||||||
block_size=args.block_size,
|
|
||||||
dtype=args.dtype,
|
|
||||||
seed=args.seed,
|
|
||||||
swap_space=args.swap_space,
|
|
||||||
max_num_batched_tokens=args.max_num_batched_tokens,
|
|
||||||
max_num_sequences=args.max_num_sequences,
|
|
||||||
num_nodes=num_nodes,
|
|
||||||
num_devices_per_node=num_devices_per_node,
|
|
||||||
distributed_init_method=distributed_init_method,
|
|
||||||
all_stage_devices=all_stage_devices,
|
|
||||||
gpu_memory=get_gpu_memory(),
|
|
||||||
cpu_memory=get_cpu_memory(),
|
|
||||||
collect_stats=True,
|
|
||||||
do_memory_analysis=args.do_memory_analysis,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a frontend.
|
|
||||||
frontend = SimpleFrontend(
|
|
||||||
model_name=args.model,
|
|
||||||
block_size=args.block_size,
|
|
||||||
)
|
|
||||||
# Generate requests.
|
|
||||||
requests = generate_text_completion_requests(
|
|
||||||
args.dataset,
|
|
||||||
args.request_rate,
|
|
||||||
args.duration,
|
|
||||||
args.seed,
|
|
||||||
args.n1,
|
|
||||||
args.n2,
|
|
||||||
args.n3,
|
|
||||||
args.n4,
|
|
||||||
args.n6,
|
|
||||||
args.n2_beam,
|
|
||||||
args.n4_beam,
|
|
||||||
args.n6_beam,
|
|
||||||
args.n8_beam,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Warm up.
|
|
||||||
logger.info('Warming up.')
|
|
||||||
num_warmup_requests = 8
|
|
||||||
warmup_input_len = 8
|
|
||||||
warmup_output_len = 32
|
|
||||||
warmup_sampling_params = SamplingParams(
|
|
||||||
n=1,
|
|
||||||
temperature=1.0,
|
|
||||||
top_p=0.99,
|
|
||||||
max_num_steps=warmup_output_len,
|
|
||||||
use_beam_search=False,
|
|
||||||
stop_token_ids=set(),
|
|
||||||
num_logprobs=0,
|
|
||||||
context_window_size=None,
|
|
||||||
)
|
|
||||||
for _ in range(num_warmup_requests):
|
|
||||||
frontend._add_query([0] * warmup_input_len, warmup_sampling_params)
|
|
||||||
server.add_sequence_groups(frontend.get_inputs())
|
|
||||||
while True:
|
|
||||||
server.step()
|
|
||||||
if not server.has_unfinished_requests():
|
|
||||||
break
|
|
||||||
|
|
||||||
# Start benchmarking.
|
|
||||||
logger.info('Start benchmarking.')
|
|
||||||
# Initialize tqdm.
|
|
||||||
pbar = tqdm(total=len(requests), desc='Finished requests')
|
|
||||||
|
|
||||||
finished = []
|
|
||||||
server.scheduler.reset_stats()
|
|
||||||
start_time = time.time()
|
|
||||||
while True:
|
|
||||||
now = time.time()
|
|
||||||
if args.timeout is not None and now - start_time > args.timeout:
|
|
||||||
logger.info('Timeout. Stop benchmarking.')
|
|
||||||
break
|
|
||||||
|
|
||||||
while requests:
|
|
||||||
if requests[0][0] <= now - start_time:
|
|
||||||
request_time, input_tokens, sampling_params = requests.pop(0)
|
|
||||||
frontend._add_query(
|
|
||||||
input_tokens, sampling_params, arrival_time=start_time + request_time)
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
server.add_sequence_groups(frontend.get_inputs())
|
|
||||||
updated_seq_groups = server.step()
|
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
for seq_group in updated_seq_groups:
|
|
||||||
if not seq_group.is_finished():
|
|
||||||
continue
|
|
||||||
arrival_time = seq_group.arrival_time
|
|
||||||
finish_time = now
|
|
||||||
for seq in seq_group.get_seqs():
|
|
||||||
seq_len = seq.get_len()
|
|
||||||
output_len = seq_len - seq.prompt_len
|
|
||||||
finished.append({
|
|
||||||
'group_id': seq_group.group_id,
|
|
||||||
'seq_id': seq.seq_id,
|
|
||||||
'arrival_time': arrival_time,
|
|
||||||
'finish_time': finish_time,
|
|
||||||
'prompt_len': seq.prompt_len,
|
|
||||||
'output_len': output_len,
|
|
||||||
})
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
if not (requests or server.has_unfinished_requests()):
|
|
||||||
break
|
|
||||||
pbar.close()
|
|
||||||
logger.info('Finish benchmarking. Saving stats.')
|
|
||||||
server.scheduler.save_stats(args.output_dir)
|
|
||||||
with open(os.path.join(args.output_dir, 'sequences.pkl'), 'wb') as f:
|
|
||||||
pickle.dump(finished, f)
|
|
||||||
logger.info('Done.')
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_name(model: str) -> str:
|
|
||||||
OPT_MODELS = [
|
|
||||||
'opt-125m',
|
|
||||||
'opt-350m',
|
|
||||||
'opt-1.3b',
|
|
||||||
'opt-2.7b',
|
|
||||||
'opt-6.7b',
|
|
||||||
'opt-13b',
|
|
||||||
'opt-30b',
|
|
||||||
'opt-66b',
|
|
||||||
'opt-175b',
|
|
||||||
]
|
|
||||||
for opt_model in OPT_MODELS:
|
|
||||||
if opt_model in model:
|
|
||||||
return opt_model
|
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(model)
|
|
||||||
assert config.model_type == 'llama'
|
|
||||||
hidden_size = config.hidden_size
|
|
||||||
if hidden_size == 4096:
|
|
||||||
return 'llama-7b'
|
|
||||||
elif hidden_size == 5120:
|
|
||||||
return 'llama-13b'
|
|
||||||
elif hidden_size == 6656:
|
|
||||||
return 'llama-30b'
|
|
||||||
elif hidden_size == 8192:
|
|
||||||
return 'llama-65b'
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unknown model: {model}')
|
|
||||||
|
|
||||||
|
|
||||||
def get_dataset_name(dataset: str) -> str:
|
|
||||||
if 'sharegpt' in dataset.lower():
|
|
||||||
return 'sharegpt'
|
|
||||||
elif 'alpaca' in dataset.lower():
|
|
||||||
return 'alpaca'
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unknown dataset: {dataset}')
|
|
||||||
|
|
||||||
|
|
||||||
def get_sampling_dir_name(
|
|
||||||
n1: float,
|
|
||||||
n2: float,
|
|
||||||
n3: float,
|
|
||||||
n4: float,
|
|
||||||
n6: float,
|
|
||||||
n2_beam: float,
|
|
||||||
n4_beam: float,
|
|
||||||
n6_beam: float,
|
|
||||||
n8_beam: float,
|
|
||||||
) -> str:
|
|
||||||
method = ''
|
|
||||||
if n1 > 0.0:
|
|
||||||
method = 'n1' if n1 == 1.0 else method + f'n1-{n1}-'
|
|
||||||
if n2 > 0.0:
|
|
||||||
method = 'n2' if n2 == 1.0 else method + f'n2-{n2}-'
|
|
||||||
if n3 > 0.0:
|
|
||||||
method = 'n3' if n3 == 1.0 else method + f'n3-{n3}-'
|
|
||||||
if n4 > 0.0:
|
|
||||||
method = 'n4' if n4 == 1.0 else method + f'n4-{n4}-'
|
|
||||||
if n6 > 0.0:
|
|
||||||
method = 'n6' if n6 == 1.0 else method + f'n6-{n6}-'
|
|
||||||
if n2_beam > 0.0:
|
|
||||||
method = 'n2-beam' if n2_beam == 1.0 else method + f'n2-beam-{n2_beam}-'
|
|
||||||
if n4_beam > 0.0:
|
|
||||||
method = 'n4-beam' if n4_beam == 1.0 else method + f'n4-beam-{n4_beam}-'
|
|
||||||
if n6_beam > 0.0:
|
|
||||||
method = 'n6-beam' if n6_beam == 1.0 else method + f'n6-beam-{n6_beam}-'
|
|
||||||
if n8_beam > 0.0:
|
|
||||||
method = 'n8-beam' if n8_beam == 1.0 else method + f'n8-beam-{n8_beam}-'
|
|
||||||
return method[:-1] if method.endswith('-') else method
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parser = argparse.ArgumentParser(description='CacheFlow simple server.')
|
|
||||||
parser = add_server_arguments(parser)
|
|
||||||
parser.add_argument('--output-dir', type=str, help='path to output directory', default=None)
|
|
||||||
|
|
||||||
parser.add_argument('--dataset', type=str, help='path to dataset', required=True)
|
|
||||||
parser.add_argument('--request-rate', type=float, help='reqs/sec', required=True)
|
|
||||||
parser.add_argument('--duration', type=int, help='duration in seconds', required=True)
|
|
||||||
parser.add_argument('--do-memory-analysis', action='store_true',
|
|
||||||
help='do memory analysis (This will lower the throughput. Use this only for analysis.)')
|
|
||||||
parser.add_argument('--timeout', type=int, help='time out in seconds', default=None)
|
|
||||||
|
|
||||||
parser.add_argument('--n1', type=float, help='ratio of requests with n=1', default=0.0)
|
|
||||||
parser.add_argument('--n2', type=float, help='ratio of requests with n=2', default=0.0)
|
|
||||||
parser.add_argument('--n3', type=float, help='ratio of requests with n=3', default=0.0)
|
|
||||||
parser.add_argument('--n4', type=float, help='ratio of requests with n=4', default=0.0)
|
|
||||||
parser.add_argument('--n6', type=float, help='ratio of requests with n=6', default=0.0)
|
|
||||||
parser.add_argument('--n2-beam', type=float, help='ratio of requests with n=2 & beam search', default=0.0)
|
|
||||||
parser.add_argument('--n4-beam', type=float, help='ratio of requests with n=4 & beam search', default=0.0)
|
|
||||||
parser.add_argument('--n6-beam', type=float, help='ratio of requests with n=6 & beam search', default=0.0)
|
|
||||||
parser.add_argument('--n8-beam', type=float, help='ratio of requests with n=8 & beam search', default=0.0)
|
|
||||||
args = parser.parse_args()
|
|
||||||
if args.n1 + args.n2 + args.n3 + args.n4 + args.n6 + args.n2_beam + args.n4_beam + args.n6_beam + args.n8_beam != 1.0:
|
|
||||||
raise ValueError('The ratios of requests must sum to 1.')
|
|
||||||
|
|
||||||
model_name = get_model_name(args.model)
|
|
||||||
dataset_name = get_dataset_name(args.dataset)
|
|
||||||
if 'opt' in model_name:
|
|
||||||
if 'opt' not in args.dataset.lower():
|
|
||||||
raise ValueError(f'OPT models can only be used with OPT datasets.')
|
|
||||||
elif 'llama' in model_name:
|
|
||||||
if 'llama' not in args.dataset.lower():
|
|
||||||
raise ValueError(f'Llama models can only be used with Llama datasets.')
|
|
||||||
|
|
||||||
dataset_name = 'sharegpt' if 'sharegpt' in args.dataset else 'alpaca'
|
|
||||||
sample_dir = get_sampling_dir_name(
|
|
||||||
args.n1, args.n2, args.n3, args.n4, args.n6, args.n2_beam, args.n4_beam, args.n6_beam, args.n8_beam)
|
|
||||||
if args.output_dir is None:
|
|
||||||
args.output_dir = os.path.join(
|
|
||||||
'../exp',
|
|
||||||
dataset_name,
|
|
||||||
f'{model_name}-tp{args.tensor_parallel_size}',
|
|
||||||
sample_dir,
|
|
||||||
'cacheflow',
|
|
||||||
f'block{args.block_size}',
|
|
||||||
f'req-rate-{args.request_rate}',
|
|
||||||
f'seed{args.seed}',
|
|
||||||
f'duration-{args.duration}',
|
|
||||||
)
|
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# Set up logging.
|
|
||||||
logging.basicConfig(
|
|
||||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
|
||||||
datefmt="%m/%d/%Y %H:%M:%S",
|
|
||||||
level=logging.INFO,
|
|
||||||
handlers=[
|
|
||||||
logging.StreamHandler(),
|
|
||||||
logging.FileHandler(os.path.join(args.output_dir, 'log.txt')),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
logger.info(args)
|
|
||||||
|
|
||||||
main(args)
|
|
(deleted file)
@@ -1,116 +0,0 @@
import pickle
|
|
||||||
import random
|
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from cacheflow.sampling_params import SamplingParams
|
|
||||||
|
|
||||||
|
|
||||||
def generate_text_completion_requests(
|
|
||||||
dataset: str,
|
|
||||||
request_rate: float,
|
|
||||||
duration: int,
|
|
||||||
seed: int,
|
|
||||||
n1: float = 0.0,
|
|
||||||
n2: float = 0.0,
|
|
||||||
n3: float = 0.0,
|
|
||||||
n4: float = 0.0,
|
|
||||||
n6: float = 0.0,
|
|
||||||
n2_beam: float = 0.0,
|
|
||||||
n4_beam: float = 0.0,
|
|
||||||
n6_beam: float = 0.0,
|
|
||||||
n8_beam: float = 0.0,
|
|
||||||
max_seq_len: int = 2048,
|
|
||||||
time_quantum: int = 10,
|
|
||||||
) -> List[Tuple[float, List[int], SamplingParams]]:
|
|
||||||
random.seed(seed)
|
|
||||||
np.random.seed(seed)
|
|
||||||
|
|
||||||
# Generate timestamps for requests using Poisson distribution.
|
|
||||||
lam = request_rate * (time_quantum / 1000)
|
|
||||||
quantums_per_sec = 1000 / time_quantum
|
|
||||||
arrival_times = np.random.poisson(
|
|
||||||
lam=lam, size=int(duration * quantums_per_sec))
|
|
||||||
timestamps = []
|
|
||||||
for i, n in enumerate(arrival_times):
|
|
||||||
timestamps += [i * (time_quantum / 1000)] * n
|
|
||||||
|
|
||||||
# Load and shuffle the dataset.
|
|
||||||
num_requests = len(timestamps)
|
|
||||||
with open(dataset, 'rb') as f:
|
|
||||||
data = pickle.load(f)
|
|
||||||
|
|
||||||
filtered = []
|
|
||||||
for pair in data:
|
|
||||||
input_tokens, output_tokens = pair
|
|
||||||
input_len = len(input_tokens)
|
|
||||||
output_len = len(output_tokens)
|
|
||||||
# Filter out too long sequences.
|
|
||||||
if input_len + output_len < max_seq_len:
|
|
||||||
# Output tokens are not needed for the benchmark.
|
|
||||||
filtered.append((input_tokens, output_len))
|
|
||||||
|
|
||||||
data = []
|
|
||||||
while len(data) < num_requests:
|
|
||||||
data += filtered
|
|
||||||
data = data[:num_requests]
|
|
||||||
# Shuffle the data.
|
|
||||||
assert len(data) == len(timestamps)
|
|
||||||
random.shuffle(data)
|
|
||||||
|
|
||||||
random_sampling_params_dict = {
|
|
||||||
'temperature': 1.0,
|
|
||||||
'top_p': 1.0,
|
|
||||||
'use_beam_search': False,
|
|
||||||
'stop_token_ids': set(),
|
|
||||||
'num_logprobs': 0,
|
|
||||||
'context_window_size': None,
|
|
||||||
}
|
|
||||||
beam_search_params_dict = {
|
|
||||||
'temperature': 0.0,
|
|
||||||
'top_p': 1.0,
|
|
||||||
'use_beam_search': True,
|
|
||||||
'stop_token_ids': set(),
|
|
||||||
'num_logprobs': 0,
|
|
||||||
'context_window_size': None,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Generate requests based on the sampling parameter ratio.
|
|
||||||
requests = []
|
|
||||||
assert n1 + n2 + n3 + n4 + n6 + n2_beam + n4_beam + n6_beam + n8_beam == 1.0
|
|
||||||
cum_sum = 0
|
|
||||||
for timestamp, pair in zip(timestamps, data):
|
|
||||||
input_tokens, output_len = pair
|
|
||||||
if cum_sum < n1 * num_requests:
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
n=1, max_num_steps=output_len, **random_sampling_params_dict)
|
|
||||||
elif cum_sum < (n1 + n2) * num_requests:
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
n=2, max_num_steps=output_len, **random_sampling_params_dict)
|
|
||||||
elif cum_sum < (n1 + n2 + n3) * num_requests:
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
n=3, max_num_steps=output_len, **random_sampling_params_dict)
|
|
||||||
elif cum_sum < (n1 + n2 + n3 + n4) * num_requests:
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
n=4, max_num_steps=output_len, **random_sampling_params_dict)
|
|
||||||
elif cum_sum < (n1 + n2 + n3 + n4 + n6) * num_requests:
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
n=6, max_num_steps=output_len, **random_sampling_params_dict)
|
|
||||||
elif cum_sum < (n1 + n2 + n3 + n4 + n6 + n2_beam) * num_requests:
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
n=2, max_num_steps=output_len, **beam_search_params_dict)
|
|
||||||
elif cum_sum < (n1 + n2 + n3 + n4 + n6 + n2_beam + n4_beam) * num_requests:
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
n=4, max_num_steps=output_len, **beam_search_params_dict)
|
|
||||||
elif cum_sum < (n1 + n2 + n3 + n4 + n6 + n2_beam + n4_beam + n6_beam) * num_requests:
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
n=6, max_num_steps=output_len, **beam_search_params_dict)
|
|
||||||
elif cum_sum < (n1 + n2 + n3 + n4 + n6 + n2_beam + n4_beam + n6_beam + n8_beam) * num_requests:
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
n=8, max_num_steps=output_len, **beam_search_params_dict)
|
|
||||||
else:
|
|
||||||
raise ValueError('Invalid request ratio.')
|
|
||||||
cum_sum += 1
|
|
||||||
requests.append((timestamp, input_tokens, sampling_params))
|
|
||||||
return requests
|
|
benchmarks/README.md (new file)
@@ -0,0 +1,8 @@
# Benchmarking vLLM

## Downloading the ShareGPT dataset

You can download the dataset by running:
```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
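The benchmark scripts below consume this JSON through their `sample_requests` helpers: keep conversations with at least two turns, use only the first two turns, tokenize them, and prune very short or very long pairs. A minimal sketch of that preprocessing, assuming the file sits in the current directory and using `facebook/opt-125m` only as an example tokenizer:

```python
# Sketch of how the ShareGPT dump becomes (prompt, prompt_len, output_len)
# triples, mirroring sample_requests() in the benchmark scripts below.
import json

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # example only

with open("ShareGPT_V3_unfiltered_cleaned_split.json") as f:
    dataset = json.load(f)

# Keep conversations with at least two turns and use only the first two.
pairs = [(d["conversations"][0]["value"], d["conversations"][1]["value"])
         for d in dataset if len(d["conversations"]) >= 2]

requests = []
for prompt, completion in pairs:
    prompt_len = len(tokenizer(prompt).input_ids)
    output_len = len(tokenizer(completion).input_ids)
    # Same pruning rules as the benchmarks: drop very short or very long pairs.
    if prompt_len < 4 or output_len < 4:
        continue
    if prompt_len > 1024 or prompt_len + output_len > 2048:
        continue
    requests.append((prompt, prompt_len, output_len))

print(f"{len(requests)} usable requests")
```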
benchmarks/benchmark_latency.py (new file)
@@ -0,0 +1,101 @@
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import time

import numpy as np
import torch
from tqdm import tqdm

from vllm import LLM, SamplingParams


def main(args: argparse.Namespace):
    print(args)

    # Process all the requests in a single batch if possible.
    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(
        model=args.model,
        tokenizer=args.tokenizer,
        quantization=args.quantization,
        tensor_parallel_size=args.tensor_parallel_size,
        max_num_seqs=args.batch_size,
        max_num_batched_tokens=args.batch_size * args.input_len,
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
    )

    sampling_params = SamplingParams(
        n=args.n,
        temperature=0.0 if args.use_beam_search else 1.0,
        top_p=1.0,
        use_beam_search=args.use_beam_search,
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)
    dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size

    def run_to_completion(profile: bool = False):
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        llm.generate(prompt_token_ids=dummy_prompt_token_ids,
                     sampling_params=sampling_params,
                     use_tqdm=False)

        end_time = time.perf_counter()
        latency = end_time - start_time
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return latency

    print("Warming up...")
    run_to_completion(profile=False)

    # Benchmark.
    latencies = []
    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
        latencies.append(run_to_completion(profile=False))
    print(f'Avg latency: {np.mean(latencies)} seconds')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Benchmark the latency of processing a single batch of '
                    'requests till completion.')
    parser.add_argument('--model', type=str, default='facebook/opt-125m')
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', None],
                        default=None)
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n',
                        type=int,
                        default=1,
                        help='Number of generated sequences per prompt.')
    parser.add_argument('--use-beam-search', action='store_true')
    parser.add_argument('--num-iters',
                        type=int,
                        default=3,
                        help='Number of iterations to run.')
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    args = parser.parse_args()
    main(args)
benchmarks/benchmark_serving.py (new file)
@@ -0,0 +1,233 @@
|
"""Benchmark online serving throughput.
|
||||||
|
|
||||||
|
On the server side, run one of the following commands:
|
||||||
|
(vLLM backend)
|
||||||
|
python -m vllm.entrypoints.api_server \
|
||||||
|
--model <your_model> --swap-space 16 \
|
||||||
|
--disable-log-requests
|
||||||
|
|
||||||
|
(TGI backend)
|
||||||
|
./launch_hf_server.sh <your_model>
|
||||||
|
|
||||||
|
On the client side, run:
|
||||||
|
python benchmarks/benchmark_serving.py \
|
||||||
|
--backend <backend> \
|
||||||
|
--tokenizer <your_model> --dataset <target_dataset> \
|
||||||
|
--request-rate <request_rate>
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from typing import AsyncGenerator, List, Tuple
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import numpy as np
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
|
# (prompt len, output len, latency)
|
||||||
|
REQUEST_LATENCY: List[Tuple[int, int, float]] = []
|
||||||
|
|
||||||
|
|
||||||
|
def sample_requests(
|
||||||
|
dataset_path: str,
|
||||||
|
num_requests: int,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
) -> List[Tuple[str, int, int]]:
|
||||||
|
# Load the dataset.
|
||||||
|
with open(dataset_path) as f:
|
||||||
|
dataset = json.load(f)
|
||||||
|
# Filter out the conversations with less than 2 turns.
|
||||||
|
dataset = [
|
||||||
|
data for data in dataset
|
||||||
|
if len(data["conversations"]) >= 2
|
||||||
|
]
|
||||||
|
# Only keep the first two turns of each conversation.
|
||||||
|
dataset = [
|
||||||
|
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||||
|
for data in dataset
|
||||||
|
]
|
||||||
|
|
||||||
|
# Tokenize the prompts and completions.
|
||||||
|
prompts = [prompt for prompt, _ in dataset]
|
||||||
|
prompt_token_ids = tokenizer(prompts).input_ids
|
||||||
|
completions = [completion for _, completion in dataset]
|
||||||
|
completion_token_ids = tokenizer(completions).input_ids
|
||||||
|
tokenized_dataset = []
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
output_len = len(completion_token_ids[i])
|
||||||
|
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
|
||||||
|
|
||||||
|
# Filter out too long sequences.
|
||||||
|
filtered_dataset: List[Tuple[str, int, int]] = []
|
||||||
|
for prompt, prompt_token_ids, output_len in tokenized_dataset:
|
||||||
|
prompt_len = len(prompt_token_ids)
|
||||||
|
if prompt_len < 4 or output_len < 4:
|
||||||
|
# Prune too short sequences.
|
||||||
|
# This is because TGI causes errors when the input or output length
|
||||||
|
# is too short.
|
||||||
|
continue
|
||||||
|
if prompt_len > 1024 or prompt_len + output_len > 2048:
|
||||||
|
# Prune too long sequences.
|
||||||
|
continue
|
||||||
|
filtered_dataset.append((prompt, prompt_len, output_len))
|
||||||
|
|
||||||
|
# Sample the requests.
|
||||||
|
sampled_requests = random.sample(filtered_dataset, num_requests)
|
||||||
|
return sampled_requests
|
||||||
|
|
||||||
|
|
||||||
|
async def get_request(
|
||||||
|
input_requests: List[Tuple[str, int, int]],
|
||||||
|
request_rate: float,
|
||||||
|
) -> AsyncGenerator[Tuple[str, int, int], None]:
|
||||||
|
input_requests = iter(input_requests)
|
||||||
|
for request in input_requests:
|
||||||
|
yield request
|
||||||
|
|
||||||
|
if request_rate == float("inf"):
|
||||||
|
# If the request rate is infinity, then we don't need to wait.
|
||||||
|
continue
|
||||||
|
# Sample the request interval from the exponential distribution.
|
||||||
|
interval = np.random.exponential(1.0 / request_rate)
|
||||||
|
# The next request will be sent after the interval.
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
|
||||||
|
|
||||||
|
async def send_request(
|
||||||
|
backend: str,
|
||||||
|
api_url: str,
|
||||||
|
prompt: str,
|
||||||
|
prompt_len: int,
|
||||||
|
output_len: int,
|
||||||
|
best_of: int,
|
||||||
|
use_beam_search: bool,
|
||||||
|
) -> None:
|
||||||
|
request_start_time = time.perf_counter()
|
||||||
|
|
||||||
|
headers = {"User-Agent": "Benchmark Client"}
|
||||||
|
if backend == "vllm":
|
||||||
|
pload = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"n": 1,
|
||||||
|
"best_of": best_of,
|
||||||
|
"use_beam_search": use_beam_search,
|
||||||
|
"temperature": 0.0 if use_beam_search else 1.0,
|
||||||
|
"top_p": 1.0,
|
||||||
|
"max_tokens": output_len,
|
||||||
|
"ignore_eos": True,
|
||||||
|
"stream": False,
|
||||||
|
}
|
||||||
|
elif backend == "tgi":
|
||||||
|
assert not use_beam_search
|
||||||
|
params = {
|
||||||
|
"best_of": best_of,
|
||||||
|
"max_new_tokens": output_len,
|
||||||
|
"do_sample": True,
|
||||||
|
}
|
||||||
|
pload = {
|
||||||
|
"inputs": prompt,
|
||||||
|
"parameters": params,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown backend: {backend}")
|
||||||
|
|
||||||
|
timeout = aiohttp.ClientTimeout(total=3 * 3600)
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
|
while True:
|
||||||
|
async with session.post(api_url, headers=headers, json=pload) as response:
|
||||||
|
chunks = []
|
||||||
|
async for chunk, _ in response.content.iter_chunks():
|
||||||
|
chunks.append(chunk)
|
||||||
|
output = b"".join(chunks).decode("utf-8")
|
||||||
|
output = json.loads(output)
|
||||||
|
|
||||||
|
# Re-send the request if it failed.
|
||||||
|
if "error" not in output:
|
||||||
|
break
|
||||||
|
|
||||||
|
request_end_time = time.perf_counter()
|
||||||
|
request_latency = request_end_time - request_start_time
|
||||||
|
REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
|
||||||
|
|
||||||
|
|
||||||
|
async def benchmark(
|
||||||
|
backend: str,
|
||||||
|
api_url: str,
|
||||||
|
input_requests: List[Tuple[str, int, int]],
|
||||||
|
best_of: int,
|
||||||
|
use_beam_search: bool,
|
||||||
|
request_rate: float,
|
||||||
|
) -> None:
|
||||||
|
tasks: List[asyncio.Task] = []
|
||||||
|
async for request in get_request(input_requests, request_rate):
|
||||||
|
prompt, prompt_len, output_len = request
|
||||||
|
task = asyncio.create_task(send_request(backend, api_url, prompt,
|
||||||
|
prompt_len, output_len,
|
||||||
|
best_of, use_beam_search))
|
||||||
|
tasks.append(task)
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args: argparse.Namespace):
|
||||||
|
print(args)
|
||||||
|
random.seed(args.seed)
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
|
||||||
|
api_url = f"http://{args.host}:{args.port}/generate"
|
||||||
|
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||||
|
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||||
|
|
||||||
|
benchmark_start_time = time.perf_counter()
|
||||||
|
asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
|
||||||
|
args.use_beam_search, args.request_rate))
|
||||||
|
benchmark_end_time = time.perf_counter()
|
||||||
|
benchmark_time = benchmark_end_time - benchmark_start_time
|
||||||
|
print(f"Total time: {benchmark_time:.2f} s")
|
||||||
|
print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
|
||||||
|
|
||||||
|
# Compute the latency statistics.
|
||||||
|
avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
|
||||||
|
print(f"Average latency: {avg_latency:.2f} s")
|
||||||
|
avg_per_token_latency = np.mean([
|
||||||
|
latency / (prompt_len + output_len)
|
||||||
|
for prompt_len, output_len, latency in REQUEST_LATENCY
|
||||||
|
])
|
||||||
|
print(f"Average latency per token: {avg_per_token_latency:.2f} s")
|
||||||
|
avg_per_output_token_latency = np.mean([
|
||||||
|
latency / output_len
|
||||||
|
for _, output_len, latency in REQUEST_LATENCY
|
||||||
|
])
|
||||||
|
print("Average latency per output token: "
|
||||||
|
f"{avg_per_output_token_latency:.2f} s")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Benchmark the online serving throughput.")
|
||||||
|
parser.add_argument("--backend", type=str, default="vllm",
|
||||||
|
choices=["vllm", "tgi"])
|
||||||
|
parser.add_argument("--host", type=str, default="localhost")
|
||||||
|
parser.add_argument("--port", type=int, default=8000)
|
||||||
|
parser.add_argument("--dataset", type=str, required=True,
|
||||||
|
help="Path to the dataset.")
|
||||||
|
parser.add_argument("--tokenizer", type=str, required=True,
|
||||||
|
help="Name or path of the tokenizer.")
|
||||||
|
parser.add_argument("--best-of", type=int, default=1,
|
||||||
|
help="Generates `best_of` sequences per prompt and "
|
||||||
|
"returns the best one.")
|
||||||
|
parser.add_argument("--use-beam-search", action="store_true")
|
||||||
|
parser.add_argument("--num-prompts", type=int, default=1000,
|
||||||
|
help="Number of prompts to process.")
|
||||||
|
parser.add_argument("--request-rate", type=float, default=float("inf"),
|
||||||
|
help="Number of requests per second. If this is inf, "
|
||||||
|
"then all the requests are sent at time 0. "
|
||||||
|
"Otherwise, we use Poisson process to synthesize "
|
||||||
|
"the request arrival times.")
|
||||||
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
|
parser.add_argument('--trust-remote-code', action='store_true',
|
||||||
|
help='trust remote code from huggingface')
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
benchmarks/benchmark_throughput.py (new file)
@@ -0,0 +1,246 @@
|
"""Benchmark offline inference throughput."""
|
||||||
|
import argparse
import json
import random
import time
from typing import List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))

    # Filter out too long sequences.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_token_ids, output_len in tokenized_dataset:
        prompt_len = len(prompt_token_ids)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    # Sample the requests.
    sampled_requests = random.sample(filtered_dataset, num_requests)
    return sampled_requests


def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
) -> float:
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
    )

    # Add the requests to the engine.
    for prompt, _, output_len in requests:
        sampling_params = SamplingParams(
            n=n,
            temperature=0.0 if use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=use_beam_search,
            ignore_eos=True,
            max_tokens=output_len,
        )
        # FIXME(woosuk): Do not use internal method.
        llm._add_request(
            prompt=prompt,
            prompt_token_ids=None,
            sampling_params=sampling_params,
        )

    start = time.perf_counter()
    # FIXME(woosuk): Do not use internal method.
    llm._run_engine(use_tqdm=True)
    end = time.perf_counter()
    return end - start


def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    use_beam_search: bool,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    assert not use_beam_search
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i in range(len(requests)):
        prompt, prompt_len, output_len = requests[i]
        # Add the prompt to the batch.
        batch.append(prompt)
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_output_len = max(max_output_len, output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            _, next_prompt_len, next_output_len = requests[i + 1]
            if (max(max_prompt_len, next_prompt_len) +
                    max(max_output_len, next_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=not use_beam_search,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    return end - start


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = get_tokenizer(args.tokenizer,
                              trust_remote_code=args.trust_remote_code)
    requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

    if args.backend == "vllm":
        elapsed_time = run_vllm(requests, args.model, args.tokenizer,
                                args.quantization, args.tensor_parallel_size,
                                args.seed, args.n, args.use_beam_search,
                                args.trust_remote_code, args.dtype)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.use_beam_search, args.hf_max_batch_size,
                              args.trust_remote_code)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")
    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="Path to the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    args = parser.parse_args()

    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    if args.tokenizer is None:
        args.tokenizer = args.model

    main(args)
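For reference, a typical run of this throughput benchmark follows the sketch below; the script path, dataset file name, and model are illustrative assumptions rather than values taken from this diff (the model shown is simply the script's default).

    python benchmarks/benchmark_throughput.py \
        --backend vllm \
        --dataset ShareGPT_V3_unfiltered_cleaned_split.json \
        --model facebook/opt-125m \
        --num-prompts 1000

The HF backend is invoked the same way with --backend hf plus the required --hf-max-batch-size value, and must be run with --tensor-parallel-size 1.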
197
benchmarks/kernels/benchmark_paged_attention.py
Normal file
@@ -0,0 +1,197 @@
import argparse
import random
import time

import torch

from vllm import attention_ops

NUM_BLOCKS = 1024
PARTITION_SIZE = 512


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    context_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device="cuda")
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_mapping = torch.repeat_interleave(
        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
        num_queries_per_kv)
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device="cuda")

    context_lens = [context_len for _ in range(num_seqs)]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

    # Create the KV cache.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
    key_cache.uniform_(-scale, scale)
    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
    value_cache = torch.empty(size=value_cache_shape,
                              dtype=dtype,
                              device="cuda")
    value_cache.uniform_(-scale, scale)

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    def run_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if version == "v1":
                attention_ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            elif version == "v2":
                attention_ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument("--version",
                        type=str,
                        choices=["v1", "v2"],
                        default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--context-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    dtype_to_torch_dtype = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
    }
    main(
        version=args.version,
        num_seqs=args.batch_size,
        context_len=args.context_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=dtype_to_torch_dtype[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
    )
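As a usage sketch for the kernel benchmark above (the flag values shown are simply the script's own defaults):

    python benchmarks/kernels/benchmark_paged_attention.py \
        --version v2 --batch-size 8 --context-len 4096 \
        --num-query-heads 64 --num-kv-heads 8 --head-size 128 --dtype half

Passing --profile runs a single iteration between cudaProfilerStart/Stop markers so the kernel can be captured with an external CUDA profiler.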
16
benchmarks/launch_tgi_server.sh
Executable file
@@ -0,0 +1,16 @@
#!/bin/bash

PORT=8000
MODEL=$1
TOKENS=$2

docker run --gpus all --shm-size 1g -p $PORT:80 \
    -v $PWD/data:/data \
    ghcr.io/huggingface/text-generation-inference:0.8 \
    --model-id $MODEL \
    --sharded false \
    --max-input-length 1024 \
    --max-total-tokens 2048 \
    --max-best-of 5 \
    --max-concurrent-requests 5000 \
    --max-batch-total-tokens $TOKENS
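Usage sketch for the script above (the model name and token budget are placeholders, not values from this diff): the script takes a model ID and a max-batch-total-tokens value as positional arguments, e.g.

    bash benchmarks/launch_tgi_server.sh facebook/opt-125m 10000

The TGI container then serves on host port 8000, as set by PORT above.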
@@ -1,179 +0,0 @@
import argparse
import asyncio
import time
from typing import List, Dict
import json

import ray
from transformers import AutoTokenizer
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import uvicorn

from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.master.server import (Server, add_server_arguments,
                                     initialize_ray_cluster)
from cacheflow.worker.controller import DeviceID
from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory

TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
app = FastAPI()


class FastAPIFrontend:
    def __init__(
        self,
        model: str,
        model_path: str,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        block_size: int,
        dtype: str,
        seed: int,
        swap_space: int,
        max_num_batched_tokens: int,
        num_nodes: int,
        num_devices_per_node: int,
        distributed_init_method: str,
        all_stage_devices: List[List[DeviceID]],
    ):
        self.block_size = block_size

        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.seq_group_counter = Counter()
        self.seq_counter = Counter()
        remote_server_class = ray.remote(num_cpus=0)(Server)
        self.server = remote_server_class.remote(
            model=model,
            model_path=model_path,
            use_dummy_weights=False,
            pipeline_parallel_size=pipeline_parallel_size,
            tensor_parallel_size=tensor_parallel_size,
            block_size=block_size,
            dtype=dtype,
            seed=seed,
            swap_space=swap_space,
            max_num_batched_tokens=max_num_batched_tokens,
            num_nodes=num_nodes,
            num_devices_per_node=num_devices_per_node,
            distributed_init_method=distributed_init_method,
            all_stage_devices=all_stage_devices,
            gpu_memory=get_gpu_memory(),
            cpu_memory=get_cpu_memory(),
        )

        self.running_seq_groups: Dict[int, SequenceGroup] = {}
        self.sequence_group_events: Dict[int, asyncio.Event] = {}
        self.is_server_running = False

    async def server_step(self):
        self.is_server_running = True
        updated_seq_groups = await self.server.step.remote()
        self.is_server_running = False
        # Notify the waiting coroutines that there are new outputs ready.
        for seq_group in updated_seq_groups:
            group_id = seq_group.group_id
            self.running_seq_groups[group_id] = seq_group
            self.sequence_group_events[group_id].set()

    async def generate(self, request_dict: Dict):
        # Preprocess the request.
        prompt = request_dict["prompt"]
        sampling_params = SamplingParams.from_dict(request_dict)
        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
        token_ids = self.tokenizer.encode(prompt)
        seqs: List[Sequence] = []
        for _ in range(sampling_params.n):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
            seqs.append(seq)

        arrival_time = time.time()
        group_id = next(self.seq_group_counter)
        seq_group = SequenceGroup(group_id, seqs, arrival_time)
        # Create an event to notify us that there is new output from the
        # cacheflow server.
        group_event = asyncio.Event()
        self.running_seq_groups[group_id] = seq_group
        self.sequence_group_events[group_id] = group_event
        # Add the request into the cacheflow server's waiting queue.
        await self.server.add_sequence_groups.remote([(seq_group, sampling_params)])
        # The cacheflow server does not have a background loop that keeps
        # processing incoming requests. Therefore, we need to keep kicking
        # the server to process the requests.
        while True:
            # Kick the server if the server is not running.
            if not self.is_server_running:
                await self.server_step()
            # Wait for new output. The group_event will be set in server_step
            # when there is new output available for the sequence group.
            # Added a timeout to prevent deadlock.
            await asyncio.wait_for(group_event.wait(), timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
            # Reset the event to wait for the next output.
            group_event.clear()
            # Decode and return new outputs.
            seq_group = self.running_seq_groups[group_id]
            all_outputs = []
            for seq in seq_group.seqs:
                token_ids = seq.get_token_ids()
                output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
                all_outputs.append(output)
            ret = {
                "text": all_outputs,
                "error": 0,
            }
            yield (json.dumps(ret) + "\0").encode("utf-8")

            # Once finished, release the resources of the sequence group.
            if seq_group.is_finished():
                del self.running_seq_groups[group_id]
                del self.sequence_group_events[group_id]
                # Kick the server if the server is not running. This prevents
                # requests from being left unexecuted in the server's waiting
                # queue.
                if not self.is_server_running:
                    await self.server_step()
                break


@app.post("/generate")
async def generate_stream(request: Request):
    request_dict = await request.json()
    return StreamingResponse(frontend.generate(request_dict))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=10002)
    parser = add_server_arguments(parser)
    args = parser.parse_args()

    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = (
        initialize_ray_cluster(
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size))

    frontend = FastAPIFrontend(
        model=args.model,
        model_path=args.model_path,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        max_num_batched_tokens=args.max_num_batched_tokens,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
    )

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
@@ -1,43 +0,0 @@
import argparse
import json
import time

import gradio as gr
import requests


def http_bot(prompt):
    headers = {"User-Agent": "Cacheflow Client"}
    pload = {
        "prompt": prompt,
        "max_num_steps": 128,
    }
    response = requests.post(args.model_url, headers=headers, json=pload, stream=True)

    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"][0]
            yield output


def build_demo():
    with gr.Blocks() as demo:
        gr.Markdown(
            "# Cacheflow demo\n"
        )
        inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")  # .style(container=False)
        outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
        inputbox.submit(http_bot, [inputbox], [outputbox])
    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=10003)
    parser.add_argument("--model-url", type=str, default="http://localhost:10002/generate")
    args = parser.parse_args()

    demo = build_demo()
    demo.queue(concurrency_count=100).launch(server_name=args.host, server_port=args.port)
@@ -1,23 +0,0 @@
import requests
import json


def http_request():
    prompt = "Ion Stoica is a"

    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "n": 4,
        "use_beam_search": True,
        "temperature": 0.0,
    }
    response = requests.post("http://localhost:10002/generate", headers=headers, json=pload, stream=True)

    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


for h in http_request():
    print(h, flush=True)
@@ -1,529 +0,0 @@
import enum
import os
import pickle
import time
from typing import Any, Dict, List, Optional, Tuple

from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.master.policy import PolicyFactory
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
from cacheflow.sequence import SequenceStatus


class PreemptionMode(enum.Enum):
    """Preemption modes.

    1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
    and swap them back in when the sequences are resumed.
    2. Recomputation: Discard the blocks of the preempted sequences and
    recompute them when the sequences are resumed, treating the sequences as
    new prompts.
    """
    SWAP = enum.auto()
    RECOMPUTE = enum.auto()


class Scheduler:

    def __init__(
        self,
        controllers: List,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        max_num_batched_tokens: int,
        max_num_sequences: int,
        collect_stats: bool,
        do_memory_analysis: bool = False,
    ) -> None:
        self.controllers = controllers
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_sequences = max_num_sequences
        self.collect_stats = collect_stats
        self.do_memory_analysis = do_memory_analysis

        # Instantiate the scheduling policy.
        self.policy = PolicyFactory.get_policy(policy_name='fcfs')
        # Create the block space manager.
        self.block_manager = BlockSpaceManager(
            block_size=block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
        )

        # Sequence groups in the WAITING state.
        self.waiting: List[SequenceGroup] = []
        # Sequence groups in the RUNNING state.
        self.running: List[SequenceGroup] = []
        # Mapping: group_id -> num_steps.
        self.num_steps: Dict[int, int] = {}
        # Mapping: group_id -> sampling params.
        self.sampling_params: Dict[int, SamplingParams] = {}
        # Sequence groups in the SWAPPED state.
        self.swapped: List[SequenceGroup] = []

        # Performance-related statistics.
        self.stats = Stats(num_gpu_blocks, num_cpu_blocks)

    def add_sequence_groups(
        self,
        seq_groups: List[Tuple[SequenceGroup, SamplingParams]],
    ) -> None:
        # Add sequence groups to the waiting queue.
        for seq_group, sampling_params in seq_groups:
            self.waiting.append(seq_group)
            self.sampling_params[seq_group.group_id] = sampling_params

    def _schedule(
        self,
    ) -> Tuple[Dict[int, int], Dict[int, int], Dict[int, List[int]], List[int]]:
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}

        # Fix the current time.
        now = time.time()

        # NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
        # in order to minimize the preemption overheads.
        # Preemption happens only when there is no available slot to keep all
        # the sequence groups in the RUNNING state.
        # In this case, the policy is responsible for deciding which sequence
        # groups to preempt.
        self.running = self.policy.sort_by_priority(now, self.running)

        # Reserve new token slots for the running sequence groups.
        running: List[SequenceGroup] = []
        preempted: List[SequenceGroup] = []
        while self.running:
            seq_group = self.running.pop(0)
            while not self.block_manager.can_append(seq_group):
                if self.running:
                    # Preempt the lowest-priority sequence groups.
                    victim_seq_group = self.running.pop(-1)
                    self._preempt(victim_seq_group, blocks_to_swap_out)
                    preempted.append(victim_seq_group)
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    self._preempt(seq_group, blocks_to_swap_out)
                    preempted.append(seq_group)
                    break
            else:
                # Append new slots to the sequence group.
                self._append(seq_group, blocks_to_copy)
                running.append(seq_group)
        self.running = running

        # Swap in the sequence groups in the SWAPPED state if possible.
        self.swapped = self.policy.sort_by_priority(now, self.swapped)
        # FCFS
        while self.swapped and not blocks_to_swap_out:
            seq_group = self.swapped[0]
            # If the sequence group has been preempted in this step, stop.
            if seq_group in preempted:
                break
            # If the sequence group cannot be swapped in, stop.
            if not self.block_manager.can_swap_in(seq_group):
                break

            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            num_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
            if len(self.running) + num_seqs > self.max_num_sequences:
                break

            seq_group = self.swapped.pop(0)
            self._swap_in(seq_group, blocks_to_swap_in)
            self._append(seq_group, blocks_to_copy)
            self.running.append(seq_group)

        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running
        )

        # Join waiting sequences if possible.
        prompt_group_ids: List[int] = []
        # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
        # prioritized over the sequence groups in the WAITING state.
        # This is because we want to bound the amount of CPU memory taken by
        # the swapped sequence groups.
        if not self.swapped:
            self.waiting = self.policy.sort_by_priority(now, self.waiting)
            while self.waiting:
                seq_group = self.waiting[0]
                # If the sequence group has been preempted in this step, stop.
                if seq_group in preempted:
                    break
                # If the sequence group cannot be allocated, stop.
                if not self.block_manager.can_allocate(seq_group):
                    break

                # If the number of batched tokens exceeds the limit, stop.
                num_prompt_tokens = seq_group.seqs[0].get_len()
                if (num_batched_tokens + num_prompt_tokens
                        > self.max_num_batched_tokens):
                    break

                # The total number of sequences in the RUNNING state should not
                # exceed the maximum number of sequences.
                num_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
                if len(self.running) + num_seqs > self.max_num_sequences:
                    break

                seq_group = self.waiting.pop(0)
                self._allocate(seq_group)
                self.running.append(seq_group)
                num_batched_tokens += num_prompt_tokens
                prompt_group_ids.append(seq_group.group_id)

        if self.collect_stats:
            if self.running or blocks_to_swap_in or blocks_to_swap_out:
                self.stats.timestamps.append(now - self.stats.start_time)
                self.stats.input_lens.append(num_batched_tokens)
                self.stats.swap_out_lens.append(len(blocks_to_swap_out) * self.block_size)
                self.stats.swap_in_lens.append(len(blocks_to_swap_in) * self.block_size)
                self.stats.num_preemption.append(len(preempted))
                self.stats.num_swapped.append(len(self.swapped))
                self.stats.num_running.append(len(self.running))
                self.stats.num_waiting.append(len(self.waiting))

                num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
                num_used_gpu_blocks = self.num_gpu_blocks - num_free_gpu_blocks
                self.stats.gpu_cache_usage.append(num_used_gpu_blocks / self.num_gpu_blocks)
                num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
                num_used_cpu_blocks = self.num_cpu_blocks - num_free_cpu_blocks
                self.stats.cpu_cache_usage.append(num_used_cpu_blocks / self.num_cpu_blocks)

                if self.do_memory_analysis:
                    block_tables = self.block_manager.block_tables
                    num_logical_blocks = 0
                    num_logical_tokens = 0
                    num_physical_blocks = 0
                    num_physical_tokens = 0
                    physical_block_numbers = set()
                    num_reserved_tokens = 0
                    for seq_group in self.running:
                        group_id = seq_group.group_id
                        sampling_params = self.sampling_params[group_id]
                        max_num_steps = sampling_params.max_num_steps
                        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                            num_logical_blocks += len(seq.logical_token_blocks)
                            num_logical_tokens += seq.get_len()

                            seq_id = seq.seq_id
                            block_table = block_tables[seq_id]
                            for i, block in enumerate(block_table):
                                if block.block_number in physical_block_numbers:
                                    continue
                                physical_block_numbers.add(block.block_number)
                                num_physical_blocks += 1
                                num_physical_tokens += seq.logical_token_blocks[i].num_tokens

                    assert num_physical_blocks == num_used_gpu_blocks
                    self.stats.num_logical_blocks.append(num_logical_blocks)
                    self.stats.num_logical_tokens.append(num_logical_tokens)
                    self.stats.num_physical_blocks.append(num_physical_blocks)
                    self.stats.num_physical_tokens.append(num_physical_tokens)
                    self.stats.num_reserved_tokens.append(num_reserved_tokens)

        return (blocks_to_swap_in,
                blocks_to_swap_out,
                blocks_to_copy,
                prompt_group_ids)

    def step(self) -> List[SequenceGroup]:
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        scheduler_output = self._schedule()
        blocks_to_swap_in = scheduler_output[0]
        blocks_to_swap_out = scheduler_output[1]
        blocks_to_copy = scheduler_output[2]
        prompt_group_ids = scheduler_output[3]

        # Create input data structures.
        input_seq_groups: List[SequenceGroupInputs] = []
        updated_seq_groups: List[SequenceGroup] = self.running.copy()

        for seq_group in self.running:
            group_id = seq_group.group_id
            is_prompt = group_id in prompt_group_ids

            input_tokens: Dict[int, List[int]] = {}
            seq_logprobs: Dict[int, float] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                if is_prompt:
                    input_tokens[seq_id] = seq.get_token_ids()
                else:
                    input_tokens[seq_id] = [seq.get_last_token_id()]
                seq_logprobs[seq_id] = seq.cumulative_logprobs
                # NOTE(woosuk): Sequences in the same group have the same
                # sequence length.
                seq_len = seq.get_len()

            input_seq_group = SequenceGroupInputs(
                group_id=group_id,
                is_prompt=is_prompt,
                input_tokens=input_tokens,
                context_len=seq_len,
                seq_logprobs=seq_logprobs,
                sampling_params=self.sampling_params[group_id],
                block_tables=block_tables,
            )
            input_seq_groups.append(input_seq_group)

        # Execute the first stage of the pipeline.
        if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out:
            # Swap in and swap out should never happen at the same time.
            assert not (blocks_to_swap_in and blocks_to_swap_out)
            self.controllers[0].execute_stage(
                input_seq_groups,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
            )

        return updated_seq_groups

    def post_step(
        self,
        seq_outputs: Dict[int, SequenceOutputs],
    ) -> None:
        # Update the running sequences and free blocks.
        for seq_group in self.running:
            group_id = seq_group.group_id
            self.num_steps[group_id] += 1
            stop_token_ids = self.sampling_params[group_id].stop_token_ids

            # Process beam search results before processing the next tokens.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                output = seq_outputs[seq.seq_id]
                if seq.seq_id != output.parent_seq_id:
                    # The sequence is a fork of the parent sequence (beam search).
                    # Free the current sequence.
                    self.block_manager.free(seq)
                    # Fork the parent sequence.
                    parent_seq = seq_group.find(output.parent_seq_id)
                    parent_seq.fork(seq)
                    self.block_manager.fork(parent_seq, seq)

            # Process the next tokens.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                # Append a new token to the sequence.
                output = seq_outputs[seq.seq_id]
                seq.append(output.output_token, output.logprobs)

                # Check if the sequence has generated a stop token.
                if output.output_token in stop_token_ids:
                    self._free_seq(seq)
                    continue

                # Check if the sequence has reached the maximum number of steps.
                max_num_steps = self.sampling_params[group_id].max_num_steps
                if self.num_steps[group_id] == max_num_steps:
                    self._free_seq(seq)
                    continue

        # Update the running sequences.
        running: List[SequenceGroup] = []
        for seq_group in self.running:
            if seq_group.is_finished():
                self._free_seq_group(seq_group)
            else:
                running.append(seq_group)
        self.running = running

    def _allocate(self, seq_group: SequenceGroup) -> None:
        self.block_manager.allocate(seq_group)
        for seq in seq_group.seqs:
            seq.status = SequenceStatus.RUNNING
        # FIXME(woosuk): Support interactive generation.
        if seq_group.group_id not in self.num_steps:
            self.num_steps[seq_group.group_id] = 0

    def _append(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            ret = self.block_manager.append(seq)
            if ret is not None:
                src_block, dst_block = ret
                if src_block in blocks_to_copy:
                    blocks_to_copy[src_block].append(dst_block)
                else:
                    blocks_to_copy[src_block] = [dst_block]

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> None:
        # If preemption mode is not specified, we determine the mode as follows:
        # We use recomputation by default since it incurs lower overhead than
        # swapping. However, when the sequence group has multiple sequences
        # (e.g., beam search), recomputation is not supported. In such a case,
        # we use swapping instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        if preemption_mode is None:
            seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
            if len(seqs) == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP
        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            assert False, 'Invalid preemption mode.'

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        assert len(seqs) == 1
        for seq in seqs:
            seq.status = SequenceStatus.WAITING
            self.block_manager.free(seq)
        self.waiting.append(seq_group)

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        for seq in seqs:
            seq.status = SequenceStatus.SWAPPED
        self._swap_out(seq_group, blocks_to_swap_out)
        self.swapped.append(seq_group)

    def _free_seq(self, seq: Sequence) -> None:
        seq.status = SequenceStatus.FINISHED
        self.block_manager.free(seq)

    def _free_seq_group(self, seq_group: SequenceGroup) -> None:
        group_id = seq_group.group_id
        del self.num_steps[group_id]
        del self.sampling_params[group_id]

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        assert self.block_manager.can_swap_out(seq_group)
        mapping = self.block_manager.swap_out(seq_group)
        blocks_to_swap_out.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED

    def reset_stats(self) -> None:
        self.stats.reset(self.num_gpu_blocks, self.num_cpu_blocks)

    def save_stats(
        self,
        output_dir: str,
    ) -> None:
        assert self.collect_stats, 'Statistics collection is disabled.'
        self.stats.save(output_dir)


class Stats:

    def __init__(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        self.start_time: float = time.time()
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks

        self.timestamps: List[float] = []
        self.input_lens: List[int] = []
        self.swap_out_lens: List[int] = []
        self.swap_in_lens: List[int] = []
        self.num_preemption: List[int] = []
        self.num_waiting: List[int] = []
        self.num_running: List[int] = []
        self.num_swapped: List[int] = []
        self.gpu_cache_usage: List[float] = []
        self.cpu_cache_usage: List[float] = []

        self.num_logical_blocks: List[int] = []
        self.num_logical_tokens: List[int] = []
        self.num_physical_blocks: List[int] = []
        self.num_physical_tokens: List[int] = []
        self.num_reserved_tokens: List[int] = []

    def reset(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        self.__init__(num_gpu_blocks, num_cpu_blocks)

    def to_dict(self) -> Dict[str, Any]:
        return {
            'start_time': self.start_time,
            'num_gpu_blocks': self.num_gpu_blocks,
            'num_cpu_blocks': self.num_cpu_blocks,
            'timestamps': self.timestamps,
            'input_lens': self.input_lens,
            'swap_out_lens': self.swap_out_lens,
            'swap_in_lens': self.swap_in_lens,
            'num_preemption': self.num_preemption,
            'num_waiting': self.num_waiting,
            'num_running': self.num_running,
            'num_swapped': self.num_swapped,
            'gpu_cache_usage': self.gpu_cache_usage,
            'cpu_cache_usage': self.cpu_cache_usage,
            'num_logical_blocks': self.num_logical_blocks,
            'num_logical_tokens': self.num_logical_tokens,
            'num_physical_blocks': self.num_physical_blocks,
            'num_physical_tokens': self.num_physical_tokens,
            'num_reserved_tokens': self.num_reserved_tokens,
        }

    def save(self, output_dir: str) -> None:
        with open(os.path.join(output_dir, 'stats.pkl'), 'wb') as f:
            pickle.dump(self.to_dict(), f)
@@ -1,192 +0,0 @@
import argparse
from typing import List, Tuple
import random

import ray

from cacheflow.master.scheduler import Scheduler
from cacheflow.models import get_memory_analyzer
from cacheflow.worker.controller import Controller, DeviceID
from cacheflow.sequence import SequenceGroup
from cacheflow.sampling_params import SamplingParams


class Server:
    def __init__(
        self,
        model: str,
        model_path: str,
        use_dummy_weights: bool,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        block_size: int,
        dtype: str,
        seed: int,
        swap_space: int,
        max_num_batched_tokens: int,
        max_num_sequences: int,
        num_nodes: int,
        num_devices_per_node: int,
        distributed_init_method: str,
        all_stage_devices: List[List[DeviceID]],
        gpu_memory: int,
        cpu_memory: int,
        collect_stats: bool = False,
        do_memory_analysis: bool = False,
    ):
        self.num_nodes = num_nodes
        self.num_devices_per_node = num_devices_per_node
        self.world_size = pipeline_parallel_size * tensor_parallel_size

        self.memory_analyzer = get_memory_analyzer(
            model_name=model,
            block_size=block_size,
            dtype=dtype,
            gpu_memory=gpu_memory,
            cpu_memory=cpu_memory,
            tensor_parallel_size=tensor_parallel_size,
        )
        self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks(
            max_num_batched_tokens=max_num_batched_tokens)
        self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks(
            swap_space=swap_space)
        print(f'# GPU blocks: {self.num_gpu_blocks}, '
              f'# CPU blocks: {self.num_cpu_blocks}')

        # Create a controller for each pipeline stage.
        self.controllers: List[Controller] = []
        for i in range(pipeline_parallel_size):
            controller = Controller(
                stage_id=i,
                stage_devices=all_stage_devices[i],
                world_size=self.world_size,
                pipeline_parallel_size=pipeline_parallel_size,
                tensor_parallel_size=tensor_parallel_size,
                distributed_init_method=distributed_init_method,
                model_name=model,
                block_size=block_size,
                num_gpu_blocks=self.num_gpu_blocks,
                num_cpu_blocks=self.num_cpu_blocks,
                dtype=dtype,
                seed=seed,
                model_path=model_path,
                use_dummy_weights=use_dummy_weights,
                max_num_batched_tokens=max_num_batched_tokens,
            )
            self.controllers.append(controller)

        # Create a scheduler.
        self.scheduler = Scheduler(
            controllers=self.controllers,
            block_size=block_size,
            num_gpu_blocks=self.num_gpu_blocks,
            num_cpu_blocks=self.num_cpu_blocks,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_sequences=max_num_sequences,
            collect_stats=collect_stats,
            do_memory_analysis=do_memory_analysis,
        )
        # Connect the controllers.
        for i in range(len(self.controllers) - 1):
            self.controllers[i].set_next(self.controllers[i + 1])
        self.controllers[-1].set_next(self.scheduler)

    def add_sequence_groups(
        self,
        sequence_groups: List[Tuple[SequenceGroup, SamplingParams]]
    ):
        self.scheduler.add_sequence_groups(sequence_groups)

    def step(self):
        return self.scheduler.step()

    def has_unfinished_requests(self):
        return (self.scheduler.waiting or self.scheduler.running or
                self.scheduler.swapped)


def initialize_ray_cluster(
    address: str = 'auto',
    pipeline_parallel_size: int = 1,
    tensor_parallel_size: int = 1,
) -> Tuple[int, int, str, List[List[DeviceID]]]:
    # Connect to a ray cluster.
    ray.init(address=address)

    # For now, assume a uniform cluster where each node has the same number
    # of GPUs.
    valid_node_resources = []
    num_devices_per_node = None
    for node in ray.nodes():
        if (not node['Alive']) or node['Resources']['GPU'] <= 0:
            continue
        if num_devices_per_node is None:
            num_devices_per_node = node['Resources']['GPU']
        else:
            assert num_devices_per_node == node['Resources']['GPU'], (
                "The number of GPUs per node is not uniform.")
        for key in node['Resources']:
            if key.startswith('node:'):
                valid_node_resources.append(key)

    num_nodes = len(valid_node_resources)

    assert (pipeline_parallel_size * tensor_parallel_size
            <= num_nodes * num_devices_per_node), (
        "The number of required GPUs exceeds the total number of "
        "available GPUs.")
    if tensor_parallel_size >= num_devices_per_node:
        assert tensor_parallel_size % num_devices_per_node == 0, (
            "The number of tensor parallelism is not divisible by the "
            "number of GPUs per node.")
    else:
        assert num_devices_per_node % tensor_parallel_size == 0, (
            "The number of GPUs per node is not divisible by the number "
            "of tensor parallelism.")

    # Assign GPUs to pipeline stages.
    rank = 0
    current_node_id = 0
    current_device_id = 0
    distributed_init_method = None
    all_stage_devices = []

    for i in range(pipeline_parallel_size):
        stage_devices = []
        for j in range(tensor_parallel_size):
            node_resource = valid_node_resources[current_node_id]
            stage_devices.append((rank, node_resource, current_device_id))
            if distributed_init_method is None:
                ip = node_resource.split("node:")[-1]
                port = random.randint(10000, 20000)
                distributed_init_method = f"tcp://{ip}:{port}"
            rank += 1
            current_device_id += 1
            if current_device_id >= num_devices_per_node:
                current_node_id += 1
                current_device_id = 0
        all_stage_devices.append(stage_devices)

    return (num_nodes, num_devices_per_node, distributed_init_method,
            all_stage_devices)


def add_server_arguments(parser: argparse.ArgumentParser):
    # Model arguments
    parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
    parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
                        help='model path to download and load the weights')
    # Parallel arguments
    parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
    # KV cache arguments
    parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
    # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
    parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
    # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
    parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
    parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
    parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
    return parser
@@ -1,69 +0,0 @@
import time
from typing import List, Optional, Set, Tuple

from transformers import AutoTokenizer

from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.utils import Counter


class SimpleFrontend:

    def __init__(
        self,
        model_name: str,
        block_size: int,
    ) -> None:
        self.block_size = block_size

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.seq_group_counter = Counter()
        self.seq_counter = Counter()
        self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []

    def add_eos_token(self, sampling_params: SamplingParams) -> SamplingParams:
        # Stop generation when we see an EOS token.
        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
        return sampling_params

    def query(
        self,
        prompt: str,
        sampling_params: SamplingParams,
    ) -> None:
        token_ids = self.tokenizer.encode(prompt)
        self._add_query(token_ids, sampling_params)

    def _add_query(
        self,
        token_ids: List[int],
        sampling_params: SamplingParams,
        arrival_time: Optional[float] = None,
    ) -> None:
        if arrival_time is None:
            arrival_time = time.time()
        seqs: List[Sequence] = []
        for _ in range(sampling_params.n):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
            seqs.append(seq)

        group_id = next(self.seq_group_counter)
        seq_group = SequenceGroup(group_id, seqs, arrival_time)
        self.inputs.append((seq_group, sampling_params))

    def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]:
        inputs = self.inputs
        self.inputs = []
        return inputs

    def print_response(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        for seq in seq_group.seqs:
            token_ids = seq.get_token_ids()
            output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
            output = output.strip()
            print(f'Seq {seq.seq_id}: {output!r}')
@@ -1,10 +0,0 @@
from cacheflow.models.input_metadata import InputMetadata
from cacheflow.models.model_utils import get_memory_analyzer
from cacheflow.models.model_utils import get_model


__all__ = [
    'InputMetadata',
    'get_memory_analyzer',
    'get_model',
]
@@ -1,20 +0,0 @@
import torch
import torch.nn as nn

from cacheflow import activation_ops


class SiluAndMul(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(
        self,
        x: torch.Tensor,        # (num_tokens, 2 * d)
    ) -> torch.Tensor:          # (num_tokens, d)
        num_tokens = x.shape[0]
        d = x.shape[1] // 2
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
        activation_ops.silu_and_mul(out, x)
        return out
@@ -1,207 +0,0 @@
from typing import Optional

from flash_attn.flash_attn_interface import _flash_attn_forward
import torch
import torch.nn as nn

from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow import pos_encoding_ops
from cacheflow.models import InputMetadata


class GPTCacheFlowAttention(nn.Module):

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = float(scale)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,       # [num_prompt_tokens, num_heads, head_size]
        query: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
        key: torch.Tensor,          # [num_prompt_tokens, num_heads, head_size]
        value: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
        cumulative_prompt_lens: torch.Tensor,   # [num_prompts + 1]
        max_prompt_len: int,
    ) -> None:
        if query.dtype == torch.float:
            raise ValueError('The float data type is not supported by '
                             'FlashAttention. Use the half data type instead.')
        head_size = query.shape[-1]
        if head_size > 128:
            raise ValueError('FlashAttention does not support head_size > 128.')

        # Directly call FlashAttention's internal function to avoid allocating
        # a new tensor for the output.
        _flash_attn_forward(
            query,
            key,
            value,
            output,
            cumulative_prompt_lens,
            cumulative_prompt_lens,
            max_prompt_len,
            max_prompt_len,
            dropout_p=0.0,
            softmax_scale=self.scale,
            causal=True,
            return_softmax=False,
        )

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,       # [num_generation_tokens, num_heads, head_size]
        query: torch.Tensor,        # [num_generation_tokens, num_heads, head_size]
        key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
    ) -> None:
        head_size = value_cache.shape[2]
        supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
        if head_size not in supported_head_sizes:
            raise ValueError(f'head_size ({head_size}) is not supported by '
                             'the single_query_cached_kv_attention kernel. '
                             'Use one of the following head sizes: '
                             f'{supported_head_sizes}.')

        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
        )

    def forward(
        self,
        query: torch.Tensor,        # [num_tokens, num_heads * head_size]
        key: torch.Tensor,          # [num_tokens, num_heads * head_size]
        value: torch.Tensor,        # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:              # [num_tokens, num_heads * head_size]
        # NOTE: The query, key, and value tensors must be sliced from a qkv
        # tensor of shape [num_tokens, 3 * num_heads * head_size].

        # Reshape the query, key, and value tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[2]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata.cumulative_prompt_lens,
                input_metadata.max_prompt_len,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        num_valid_tokens = input_metadata.num_valid_tokens
        if num_valid_tokens > 0:
            # The stride is 3 because the key and value are sliced from qkv.
            cache_ops.reshape_and_cache(
                key[:num_valid_tokens],
                value[:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata.slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, num_heads * head_size)


class OPTCacheFlowAttention(GPTCacheFlowAttention):
    """OPT uses the same attention mechanism as GPT."""

    def __init__(self, scale: float) -> None:
        super().__init__(scale)


class LlamaCacheFlowAttention(GPTCacheFlowAttention):
    """Llama uses GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        scale: float,
        head_size: int,
        max_position: int = 8192,
        base: int = 10000,
    ) -> None:
        super().__init__(scale)

        # Create the cos and sin cache.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
        t = torch.arange(max_position).float()
        freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)

        # FIXME(woosuk): This assumes that we configure the default dtype when
        # initializing the model. Make it more robust.
        torch_dtype = torch.get_default_dtype()
        cache = cache.to(torch_dtype)
        # Embedding size: [max_position, head_size]
        self.register_buffer('cos_sin_cache', cache, persistent=False)

    def forward(
        self,
        positions: torch.LongTensor,    # [num_tokens]
        query: torch.Tensor,            # [num_tokens, num_heads * head_size]
        key: torch.Tensor,              # [num_tokens, num_heads * head_size]
        value: torch.Tensor,            # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,        # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,      # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                  # [num_tokens, num_heads * head_size]
        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        pos_encoding_ops.rotary_embedding_neox(
            positions,
            query,
            key,
            self.cos_sin_cache,
        )
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )
@@ -1,55 +0,0 @@
from typing import List, Dict, Tuple

import torch

from cacheflow.sampling_params import SamplingParams


class InputMetadata:

    def __init__(
        self,
        seq_groups: List[Tuple[List[int], SamplingParams]],
        seq_logprobs: Dict[int, float],     # Seq id -> cumulative logprobs.
        prompt_lens: List[int],
        cumulative_prompt_lens: torch.Tensor,
        slot_mapping: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        block_tables: torch.Tensor,
    ) -> None:
        self.seq_groups = seq_groups
        self.seq_logprobs = seq_logprobs
        self.prompt_lens = prompt_lens
        self.cumulative_prompt_lens = cumulative_prompt_lens
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.max_context_len = max_context_len
        self.block_tables = block_tables

        self.num_prompts = len(prompt_lens)
        self.num_prompt_tokens = sum(prompt_lens)
        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
        self.num_generation_tokens = context_lens.shape[0]
        self.num_valid_tokens = slot_mapping.shape[0]
        if block_tables.numel() > 0:
            self.max_num_blocks_per_seq = block_tables.shape[1]
        else:
            self.max_num_blocks_per_seq = 0
        assert block_tables.shape[0] == self.num_generation_tokens
        assert context_lens.shape[0] == self.num_generation_tokens

    def __repr__(self) -> str:
        return (f'InputMetadata('
                f'num_prompts={self.num_prompts}, '
                f'num_prompt_tokens={self.num_prompt_tokens}, '
                f'max_prompt_len={self.max_prompt_len}, '
                f'num_generation_tokens={self.num_generation_tokens}, '
                f'num_valid_tokens={self.num_valid_tokens}, '
                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
                f'max_context_len={self.max_context_len}), '
                f'prompt_lens={self.prompt_lens}, '
                f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
                f'slot_mapping={self.slot_mapping}, '
                f'context_lens={self.context_lens}, '
                f'block_tables={self.block_tables})')
@@ -1,292 +0,0 @@
"""1D LLaMA model compatible with HuggingFace weights."""
import os
import glob
import filelock
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from transformers import LlamaConfig

from cacheflow.models import InputMetadata
from cacheflow.models.activation import SiluAndMul
from cacheflow.models.attention import LlamaCacheFlowAttention
from cacheflow.models.layernorm import RMSNorm
from cacheflow.models.sample import Sampler
from cacheflow.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
                                                      ColumnParallelLinear,
                                                      RowParallelLinear)
from cacheflow.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
                                                 bias=False, gather_output=False,
                                                 perform_initialization=False)
        self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                           bias=False, input_is_parallel=True,
                                           perform_initialization=False)
        if hidden_act != 'silu':
            raise ValueError(f'Unsupported activation: {hidden_act}. '
                             'Only silu is supported for now.')
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = hidden_size // self.total_num_heads
        self.scaling = self.head_dim ** -0.5

        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
            3 * self.total_num_heads * self.head_dim,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            perform_initialization=False,
        )
        self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)

    def forward(
        self,
        positions: torch.LongTensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(
            positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.LongTensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class LlamaModel(nn.Module):

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
                                                   perform_initialization=False)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = LlamaModel(config)
        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            gather_output=False,
                                            perform_initialization=False)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.lm_head.weight, hidden_states, input_metadata)
        return next_tokens

    _column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
                                "qkv_proj.weight", "gate_proj.weight",
                                "up_proj.weight"]
    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

    def load_weights(self, weights_path: str):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, param in state_dict.items():
            if "qkv_proj" in name or "gate_up_proj" in name:
                if "qkv_proj" in name:
                    original_name = "qkv_proj"
                    weight_names = ["q_proj", "k_proj", "v_proj"]
                    shard_size = param.shape[0] // 3
                else:
                    original_name = "gate_up_proj"
                    weight_names = ["gate_proj", "up_proj"]
                    shard_size = param.shape[0] // 2
                weights_to_concat = []
                for weight_name in weight_names:
                    weight = np.load(os.path.join(
                        weights_path, name.replace(original_name, weight_name)))
                    weights_to_concat.append(weight[
                        shard_size * tensor_model_parallel_rank
                        :shard_size * (tensor_model_parallel_rank + 1)])
                loaded_weight = torch.from_numpy(
                    np.concatenate(weights_to_concat, axis=0))
            else:
                loaded_weight = torch.from_numpy(
                    np.load(os.path.join(weights_path, name)))
                for p in self._column_parallel_weights:
                    if p in name:
                        shard_size = param.shape[0]
                        loaded_weight = loaded_weight[
                            shard_size * tensor_model_parallel_rank
                            :shard_size * (tensor_model_parallel_rank + 1)]
                        break
                for p in self._row_parallel_weights:
                    if p in name:
                        shard_size = param.shape[1]
                        loaded_weight = loaded_weight[
                            :,
                            shard_size * tensor_model_parallel_rank
                            :shard_size * (tensor_model_parallel_rank + 1)]
                        break

            assert param.shape == loaded_weight.shape
            param.data.copy_(loaded_weight)

    @staticmethod
    def get_weights(model_name: str, path: str):
        if not os.path.isfile(os.path.join(model_name, "config.json")):
            raise ValueError("LLaMA model's model_name has to be a path"
                             "to the huggingface model's directory.")
        path = os.path.join(model_name, f"np")
        path = os.path.abspath(os.path.expanduser(path))
        os.makedirs(path, exist_ok=True)
        lock_path = os.path.join(path, "file_lock")
        lock = filelock.FileLock(lock_path)

        with lock:
            test_weight_path = os.path.join(path, "model.embed_tokens.weight")
            if os.path.exists(test_weight_path):
                return path

            bin_files = glob.glob(os.path.join(model_name, "*.bin"))

            for bin_file in tqdm(bin_files, desc="Convert format"):
                state = torch.load(bin_file, map_location="cpu")
                for name, param in tqdm(state.items(), leave=False):
                    param_path = os.path.join(path, name)
                    with open(param_path, "wb") as f:
                        np.save(f, param.cpu().detach().numpy())

            return path

    def initialize_dummy_weights(self) -> None:
        for param in self.state_dict().values():
            param.data.uniform_(-0.1, 0.1)
@@ -1,240 +0,0 @@
import torch
from transformers import AutoConfig

from cacheflow.models.utils import get_dtype_size

_GiB = 1 << 30


class CacheFlowMemoryAnalyzer:

    def get_max_num_gpu_blocks(
        self,
        max_num_batched_tokens: int,
        memory_utilization: float,
    ) -> int:
        raise NotImplementedError()

    def get_workspace_size(self) -> int:
        return 1 * _GiB

    def get_cache_block_size(self) -> int:
        raise NotImplementedError()

    def get_max_num_cpu_blocks(
        self,
        swap_space: int,
    ) -> int:
        swap_space = swap_space * _GiB
        cpu_memory = self.cpu_memory
        if swap_space > 0.8 * cpu_memory:
            raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) '
                             'takes more than 80% of the available memory '
                             f'({cpu_memory / _GiB:.2f} GiB).'
                             'Please check the swap space size.')
        if swap_space > 0.5 * cpu_memory:
            print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
                  'takes more than 50% of the available memory '
                  f'({cpu_memory / _GiB:.2f} GiB).'
                  'This may slow the system performance.')
        max_num_blocks = swap_space // self.get_cache_block_size()
        return max_num_blocks


class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):

    def __init__(
        self,
        model_name: str,
        block_size: int,
        dtype: torch.dtype,
        gpu_memory: int,
        cpu_memory: int,
        tensor_parallel_size: int,
    ) -> None:
        self.model_name = model_name
        self.block_size = block_size
        self.dtype = dtype
        self.gpu_memory = gpu_memory
        self.cpu_memory = cpu_memory
        self.tensor_parallel_size = tensor_parallel_size

        config = AutoConfig.from_pretrained(model_name)
        self.num_layers = config.num_hidden_layers
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_size = config.hidden_size // self.num_heads
        self.ffn_size = config.ffn_dim
        self.embedding_size = config.word_embed_proj_dim
        self.vocab_size = config.vocab_size
        self.max_position = config.max_position_embeddings

    def _get_param_size(self) -> int:
        word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size
        if self.embedding_size != self.hidden_size:
            # Project in/out.
            word_embedding += 2 * self.embedding_size * self.hidden_size
        position_embedding = self.max_position * self.hidden_size

        ln1 = 2 * self.hidden_size
        q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
        k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
        v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
        out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
        mha = ln1 + q + k + v + out

        ln2 = 2 * self.hidden_size
        ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
        ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
        ffn = ln2 + ffn1 + ffn2

        total = (word_embedding + position_embedding +
                 self.num_layers * (mha + ffn))
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * total

    def _get_max_act_size(
        self,
        max_num_batched_tokens: int,
    ) -> int:
        # NOTE: We approximately calculate the maximum activation size by
        # estimating
        # 1) the maximum activation tensor size during inference
        # 2) the residual tensor size during inference
        # Here, we assume that FlashAttention is used and
        # thus the attention maps are never materialized in GPU DRAM.
        residual = max_num_batched_tokens * self.hidden_size
        qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
        ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
        # Double the activation size for input and output.
        max_act = 2 * (max(qkv, ffn) + residual)
        # Size of output logits.
        output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
        max_act = max(max_act, output_logits)
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * max_act

    def get_cache_block_size(self) -> int:
        key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
        value_cache_block = key_cache_block
        total = self.num_layers * (key_cache_block + value_cache_block)
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * total

    def get_max_num_gpu_blocks(
        self,
        max_num_batched_tokens: int,
        memory_utilization: float = 0.95,
    ) -> int:
        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
        usable_memory = int(memory_utilization * self.gpu_memory)

        param_size = self._get_param_size()
        act_size = self._get_max_act_size(max_num_batched_tokens)
        workspace_size = self.get_workspace_size()

        max_cache_size = usable_memory - (param_size + act_size + workspace_size)
        if max_cache_size <= 0:
            raise RuntimeError('Not enough GPU memory.')
        max_num_blocks = max_cache_size // self.get_cache_block_size()
        return max_num_blocks


class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):

    def __init__(
        self,
        model_name: str,
        block_size: int,
        dtype: torch.dtype,
        gpu_memory: int,
        cpu_memory: int,
        tensor_parallel_size: int,
    ) -> None:
        self.model_name = model_name
        self.block_size = block_size
        self.dtype = dtype
        self.gpu_memory = gpu_memory
        self.cpu_memory = cpu_memory
        self.tensor_parallel_size = tensor_parallel_size

        config = AutoConfig.from_pretrained(model_name)
        self.num_layers = config.num_hidden_layers
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_size = config.hidden_size // self.num_heads
        self.ffn_size = config.intermediate_size
        self.vocab_size = config.vocab_size
        self.max_position = 8192

    def _get_param_size(self) -> int:
        word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
        position_embedding = self.max_position * self.hidden_size

        # NOTE: LLaMA does not have bias terms.
        ln1 = self.hidden_size
        q = self.hidden_size * self.hidden_size // self.tensor_parallel_size
        k = self.hidden_size * self.hidden_size // self.tensor_parallel_size
        v = self.hidden_size * self.hidden_size // self.tensor_parallel_size
        out = self.hidden_size * self.hidden_size // self.tensor_parallel_size
        # Rotary embedding.
        # TODO(woosuk): Share the rotary embedding between layers.
        rot = self.max_position * self.head_size
        mha = ln1 + q + k + v + out + rot

        ln2 = self.hidden_size
        gate = self.hidden_size * self.ffn_size // self.tensor_parallel_size
        down = self.ffn_size * self.hidden_size // self.tensor_parallel_size
        up = self.hidden_size * self.ffn_size // self.tensor_parallel_size
        ffn = ln2 + gate + down + up

        total = (word_embedding + position_embedding + self.num_layers * (mha + ffn))
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * total

    def _get_max_act_size(
        self,
        max_num_batched_tokens: int,
    ) -> int:
        # NOTE: We approximately calculate the maximum activation size by
        # estimating
        # 1) the maximum activation tensor size during inference
        # 2) the residual tensor size during inference
        # Here, we assume that FlashAttention is used and
        # thus the attention maps are never materialized in GPU DRAM.
        residual = max_num_batched_tokens * self.hidden_size
        qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
        ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
        # Double the activation size for input and output.
        max_act = 2 * (max(qkv, ffn) + residual)
        # Size of output logits.
        output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
        max_act = max(max_act, output_logits)
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * max_act

    def get_cache_block_size(self) -> int:
        key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
        value_cache_block = key_cache_block
        total = self.num_layers * (key_cache_block + value_cache_block)
        dtype_size = get_dtype_size(self.dtype)
        return dtype_size * total

    def get_max_num_gpu_blocks(
        self,
        max_num_batched_tokens: int,
        memory_utilization: float = 0.95,
    ) -> int:
        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
        gpu_memory = self.gpu_memory
        usable_memory = int(memory_utilization * gpu_memory)

        param_size = self._get_param_size()
        act_size = self._get_max_act_size(max_num_batched_tokens)
        workspace_size = self.get_workspace_size()

        max_cache_size = usable_memory - (param_size + act_size + workspace_size)
        if max_cache_size <= 0:
            raise RuntimeError('Not enough GPU memory.')
        max_num_blocks = max_cache_size // self.get_cache_block_size()
        return max_num_blocks
@@ -1,72 +0,0 @@
from typing import Union

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoConfig

from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
from cacheflow.models.llama import LlamaForCausalLM
from cacheflow.models.opt import OPTForCausalLM
from cacheflow.models.utils import get_torch_dtype


_MODELS = {
    'llama': LlamaForCausalLM,
    'opt': OPTForCausalLM,
}

_MEMORY_ANALYZERS = {
    'llama': LlamaMemoryAnalyzer,
    'opt': OPTMemoryAnalyzer,
}


def get_model(
    model_name: str,
    dtype: Union[torch.dtype, str],
    path: str,
    use_dummy_weights: bool,
) -> nn.Module:
    torch_dtype = get_torch_dtype(dtype)
    torch.set_default_dtype(torch_dtype)
    config = AutoConfig.from_pretrained(model_name)
    for model_class_name, model_class in _MODELS.items():
        if model_class_name in model_name:
            if use_dummy_weights:
                # Create a model instance.
                # The weights will be initialized as empty tensors.
                model = model_class(config)
                model = model.cuda()
                # NOTE(woosuk): For precise performance evaluation, we assign
                # random values to the weights.
                model.initialize_dummy_weights()
            else:
                # Download model weights if it's not cached.
                weights_dir = model_class.get_weights(model_name, path=path)
                # Create a model instance.
                model = model_class(config)
                # Load the weights from the cached or downloaded files.
                model.load_weights(weights_dir)
                model = model.cuda()
            return model.eval(), torch_dtype
    raise ValueError(f'Unsupported model name: {model_name}')


def get_memory_analyzer(
    model_name: str,
    block_size: int,
    dtype: Union[torch.dtype, str],
    gpu_memory: int,
    cpu_memory: int,
    tensor_parallel_size: int = 1,
) -> CacheFlowMemoryAnalyzer:
    torch_dtype = get_torch_dtype(dtype)
    for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
        if model_class in model_name:
            return memory_analyzer(
                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
                tensor_parallel_size)
    raise ValueError(f'Unsupported model name: {model_name}')
@@ -1,330 +0,0 @@
"""1D OPT model compatible with HuggingFace weights."""
import os
import glob
import filelock
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from transformers import OPTConfig
from huggingface_hub import snapshot_download

from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
                                                      ColumnParallelLinear,
                                                      RowParallelLinear)
from cacheflow.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]


class OPTLearnedPositionalEmbedding(nn.Embedding):

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.LongTensor):
        return super().forward(positions + self.offset)


class OPTAttention(nn.Module):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        total_num_heads = num_heads
        assert num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim ** -0.5

        self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
                                             gather_output=False,
                                             perform_initialization=False)
        self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                          input_is_parallel=True,
                                          perform_initialization=False)
        self.attn = OPTCacheFlowAttention(scale=self.scaling)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(
            q, k, v, key_cache, value_cache, input_metadata, cache_event)
        output, _ = self.out_proj(attn_output)
        return output


class OPTDecoderLayer(nn.Module):

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.enable_bias,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        assert config.activation_function == 'relu'
        self.activation_fn = nn.ReLU()

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
        self.fc1 = ColumnParallelLinear(self.embed_dim, config.ffn_dim,
                                        bias=config.enable_bias,
                                        gather_output=False,
                                        perform_initialization=False)
        self.fc2 = RowParallelLinear(config.ffn_dim, self.embed_dim,
                                     bias=config.enable_bias,
                                     input_is_parallel=True,
                                     perform_initialization=False)
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class OPTDecoder(nn.Module):

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.word_embed_proj_dim,
                                                   perform_initialization=False)
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        # Project out & in will be replicated if they exist.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
        else:
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
        # with checkpoints that have been fine-tuned before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
            )
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds

        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states, kv_caches[i], input_metadata, cache_event)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        return hidden_states


class OPTModel(nn.Module):

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.decoder = OPTDecoder(config)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        return self.decoder(
            input_ids, positions, kv_caches, input_metadata, cache_events)


class OPTForCausalLM(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = OPTModel(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        # parallelism
        self.lm_head_weight = self.model.decoder.embed_tokens.weight
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.model(
            input_ids, positions, kv_caches, input_metadata, cache_events)
        next_tokens = self.sampler(
            self.lm_head_weight, hidden_states, input_metadata)
        return next_tokens

    _column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
    _row_parallel_weights = ["out_proj.weight", "fc2.weight"]

    def load_weights(self, weights_path: str):
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, param in state_dict.items():
            if "lm_head_weight" in name:
                continue
            if "qkv_proj" in name:
                shard_size = param.shape[0] // 3
                weights_to_concat = []
                for weight_name in ["q_proj", "k_proj", "v_proj"]:
                    weight = np.load(os.path.join(
                        weights_path, name.replace("qkv_proj", weight_name)))
                    weights_to_concat.append(weight[
                        shard_size * tensor_model_parallel_rank
                        :shard_size * (tensor_model_parallel_rank + 1)])
                loaded_weight = torch.from_numpy(
                    np.concatenate(weights_to_concat, axis=0))
            else:
                loaded_weight = torch.from_numpy(
                    np.load(os.path.join(weights_path, name)))
                for p in self._column_parallel_weights:
                    if p in name:
                        shard_size = param.shape[0]
                        loaded_weight = loaded_weight[
                            shard_size * tensor_model_parallel_rank
                            :shard_size * (tensor_model_parallel_rank + 1)]
                        break
                for p in self._row_parallel_weights:
                    if p in name:
                        shard_size = param.shape[1]
                        loaded_weight = loaded_weight[
                            :,
                            shard_size * tensor_model_parallel_rank
                            :shard_size * (tensor_model_parallel_rank + 1)]
                        break

            assert param.shape == loaded_weight.shape
            param.data.copy_(loaded_weight)

    @staticmethod
    def get_weights(model_name: str, path: str):
        path = os.path.join(path, f"{model_name}-np")
        path = os.path.abspath(os.path.expanduser(path))
        os.makedirs(path, exist_ok=True)
        lock_path = os.path.join(path, "file_lock")
        lock = filelock.FileLock(lock_path)

        with lock:
            test_weight_path = os.path.join(
                path, "model.decoder.embed_positions.weight")
            if os.path.exists(test_weight_path):
                return path

            folder = snapshot_download(model_name, allow_patterns="*.bin",
                                       cache_dir=os.path.join(path, "cache"))
            bin_files = glob.glob(os.path.join(folder, "*.bin"))

            for bin_file in tqdm(bin_files, desc="Convert format"):
                state = torch.load(bin_file, map_location="cpu")
                for name, param in tqdm(state.items(), leave=False):
                    if name.startswith("decoder."):
                        name = "model." + name
                    param_path = os.path.join(path, name)
                    with open(param_path, "wb") as f:
                        np.save(f, param.cpu().detach().numpy())

            return path

    def initialize_dummy_weights(self) -> None:
        for param in self.state_dict().values():
            param.data.uniform_(-0.1, 0.1)
@ -1,287 +0,0 @@
|
|||||||
from typing import Dict, List, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from cacheflow.models import InputMetadata
|
|
||||||
from cacheflow.sampling_params import SamplingParams
|
|
||||||
from cacheflow.sequence import SequenceOutputs
|
|
||||||
from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_parallel_region
|
|
||||||
|
|
||||||
|
|
||||||
class Sampler(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
embedding: torch.Tensor,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
input_metadata: InputMetadata,
|
|
||||||
) -> Dict[int, SequenceOutputs]:
|
|
||||||
# Get the hidden states that we use for sampling.
|
|
||||||
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
|
|
||||||
|
|
||||||
# Get the logits for the next tokens.
|
|
||||||
logits = torch.matmul(hidden_states, embedding.t())
|
|
||||||
logits = gather_from_tensor_model_parallel_region(logits)
|
|
||||||
|
|
||||||
# Apply temperature scaling.
|
|
||||||
temperatures = _get_temperatures(input_metadata)
|
|
||||||
assert len(temperatures) == logits.shape[0]
|
|
||||||
if any(t != 1.0 for t in temperatures):
|
|
||||||
t = torch.tensor(
|
|
||||||
temperatures, dtype=logits.dtype, device=logits.device)
|
|
||||||
# Use in-place division to avoid creating a new tensor.
|
|
||||||
logits.div_(t.unsqueeze(dim=1))
|
|
||||||
|
|
||||||
# We use float32 for probabilities and log probabilities.
|
|
||||||
# Compute the probabilities.
|
|
||||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
|
||||||
# Compute the log probabilities (before applying top-p).
|
|
||||||
logprobs = torch.log(probs)
|
|
||||||
|
|
||||||
# Apply top-p truncation.
|
|
||||||
top_ps = _get_top_ps(input_metadata)
|
|
||||||
assert len(top_ps) == probs.shape[0]
|
|
||||||
if any(p < 1.0 for p in top_ps):
|
|
||||||
p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
|
|
||||||
probs = _apply_top_p(probs, p)
|
|
||||||
|
|
||||||
# Sample the next tokens.
|
|
||||||
return _sample(probs, logprobs, input_metadata)
|
|
||||||
|
|
||||||
|
|
||||||
def _prune_hidden_states(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
input_metadata: InputMetadata,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
start_idx = 0
|
|
||||||
last_token_indicies: List[int] = []
|
|
||||||
for prompt_len in input_metadata.prompt_lens:
|
|
||||||
last_token_indicies.append(start_idx + prompt_len - 1)
|
|
||||||
start_idx += prompt_len
|
|
||||||
last_token_indicies.extend(
|
|
||||||
range(start_idx, start_idx + input_metadata.num_generation_tokens))
|
|
||||||
return hidden_states[last_token_indicies]
|
|
||||||
|
|
||||||
|
|
||||||
def _get_temperatures(
|
|
||||||
input_metadata: InputMetadata,
|
|
||||||
) -> List[float]:
|
|
||||||
# Collect the temperatures for the logits.
|
|
||||||
temperatures: List[float] = []
|
|
||||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
|
||||||
seq_ids, sampling_params = seq_group
|
|
||||||
temperature = sampling_params.temperature
|
|
||||||
if temperature == 0.0:
|
|
||||||
# NOTE: Zero temperature means deterministic sampling
|
|
||||||
# (i.e., greedy sampling or beam search).
|
|
||||||
# Set the temperature to 1 to avoid division by zero.
|
|
||||||
temperature = 1.0
|
|
||||||
|
|
||||||
if i < input_metadata.num_prompts:
|
|
||||||
# A prompt input.
|
|
||||||
temperatures.append(temperature)
|
|
||||||
else:
|
|
||||||
# A generation token.
|
|
||||||
temperatures += [temperature] * len(seq_ids)
|
|
||||||
return temperatures
|
|
||||||
|
|
||||||
|
|
||||||
def _get_top_ps(
|
|
||||||
input_metadata: InputMetadata,
|
|
||||||
) -> List[float]:
|
|
||||||
top_ps: List[float] = []
|
|
||||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
|
||||||
seq_ids, sampling_params = seq_group
|
|
||||||
if i < input_metadata.num_prompts:
|
|
||||||
# A prompt input.
|
|
||||||
top_ps.append(sampling_params.top_p)
|
|
||||||
else:
|
|
||||||
# A generation token.
|
|
||||||
top_ps += [sampling_params.top_p] * len(seq_ids)
|
|
||||||
return top_ps
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_top_p(
|
|
||||||
probs: torch.Tensor,
|
|
||||||
p: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
# TODO(woosuk): Optimize.
|
|
||||||
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
|
|
||||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
|
||||||
mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
|
|
||||||
probs_sort[mask] = 0.0
|
|
||||||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
|
||||||
probs = torch.gather(
|
|
||||||
probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
|
|
||||||
return probs
|
|
||||||
|
|
||||||
|
|
||||||
def _get_topk_logprobs(
    logprobs: torch.Tensor,
    num_logprobs: int,
) -> Dict[int, float]:
    if num_logprobs == 0:
        return {}

    topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
    if num_logprobs == 1:
        topk_logprobs = [topk_logprobs.item()]
        topk_ids = [topk_ids.item()]
    else:
        topk_logprobs = topk_logprobs.tolist()
        topk_ids = topk_ids.tolist()

    token_to_logprob: Dict[int, float] = {}
    for token_id, logprob in zip(topk_ids, topk_logprobs):
        token_to_logprob[token_id] = logprob
    return token_to_logprob


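A minimal usage sketch of `_get_topk_logprobs` on a hand-built log-probability vector; the values are arbitrary.

```python
import torch

logprobs = torch.log(torch.tensor([0.1, 0.6, 0.3]))
print(_get_topk_logprobs(logprobs, num_logprobs=2))
# e.g. {1: -0.5108..., 2: -1.2039...}  -- token id -> log probability
```
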
def _sample_from_prompt(
    prob: torch.Tensor,
    sampling_params: SamplingParams,
) -> List[int]:
    if sampling_params.use_beam_search:
        # Beam search.
        beam_width = sampling_params.n
        _, next_token_ids = torch.topk(prob, beam_width)
        next_token_ids = next_token_ids.tolist()
    elif sampling_params.temperature == 0.0:
        # Greedy sampling.
        assert sampling_params.n == 1
        next_token_id = torch.argmax(prob)
        next_token_ids = [next_token_id.item()]
    else:
        # Nucleus sampling.
        # Sample n tokens for the prompt.
        n = sampling_params.n
        next_token_ids = torch.multinomial(
            prob, num_samples=n, replacement=True)
        next_token_ids = next_token_ids.tolist()
    return next_token_ids


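A quick sketch of the two most common branches. The dataclass below is a stand-in that only carries the fields `_sample_from_prompt` reads; it is not the project's `SamplingParams` class.

```python
import torch
from dataclasses import dataclass

@dataclass
class _FakeSamplingParams:  # stand-in with only the fields the function reads
    n: int = 2
    temperature: float = 1.0
    use_beam_search: bool = False

prob = torch.tensor([0.1, 0.7, 0.2])
# Nucleus branch: draws n token ids with replacement.
print(_sample_from_prompt(prob, _FakeSamplingParams()))
# Greedy branch: temperature == 0.0 requires n == 1 and returns the argmax.
print(_sample_from_prompt(prob, _FakeSamplingParams(n=1, temperature=0.0)))
```
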
def _sample_from_generation_tokens(
    seq_ids: List[int],
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    seq_logprobs: List[float],
    sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
    # NOTE(woosuk): sampling_params.n can be greater than
    # len(seq_ids) because some sequences in the group might have
    # been already terminated.
    if sampling_params.use_beam_search:
        # Beam search.
        # Add cumulative logprobs for the sequences in the group.
        seq_logprobs = torch.tensor(
            seq_logprobs, dtype=torch.float, device=logprobs.device)
        logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)

        vocab_size = logprobs.size(-1)
        beam_width = len(seq_ids)
        _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
        topk_ids = topk_ids.tolist()
        seq_idx = [i // vocab_size for i in topk_ids]
        beam_seq_ids = [seq_ids[i] for i in seq_idx]
        token_ids = [i % vocab_size for i in topk_ids]

        beam_outputs: Dict[int, Tuple[int, int]] = {}
        outstanding_beams: List[Tuple[int, int]] = []
        # If a beam survives, continue with it.
        for seq_id, token_id in zip(beam_seq_ids, token_ids):
            if seq_id not in beam_outputs:
                beam_outputs[seq_id] = (seq_id, token_id)
            else:
                outstanding_beams.append((seq_id, token_id))

        # If a beam is discarded, fork another beam.
        for seq_id in seq_ids:
            if seq_id not in beam_outputs:
                beam_outputs[seq_id] = outstanding_beams.pop()
        assert not outstanding_beams

        parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
        next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
    elif sampling_params.temperature == 0.0:
        # Greedy sampling.
        assert len(seq_ids) == 1
        next_token_id = torch.argmax(probs, dim=-1)
        next_token_ids = [next_token_id.item()]
        parent_seq_ids = seq_ids
    else:
        # Nucleus sampling.
        # Sample 1 token for each sequence in the group.
        next_token_ids = torch.multinomial(
            probs, num_samples=1, replacement=True)
        next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
        parent_seq_ids = seq_ids
    return parent_seq_ids, next_token_ids


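The beam bookkeeping above is the subtle part, so here is a plain-Python rerun of the two loops with made-up ids, showing how a beam that wins twice forks over a beam that was dropped.

```python
# Beam width 2: sequences 10 and 11, and the top-2 candidates both came
# from sequence 10 (it produced tokens 7 and 9; sequence 11 got nothing).
seq_ids = [10, 11]
beam_seq_ids = [10, 10]
token_ids = [7, 9]

beam_outputs, outstanding_beams = {}, []
for seq_id, token_id in zip(beam_seq_ids, token_ids):
    if seq_id not in beam_outputs:
        beam_outputs[seq_id] = (seq_id, token_id)     # beam survives
    else:
        outstanding_beams.append((seq_id, token_id))  # extra winner, parked

for seq_id in seq_ids:
    if seq_id not in beam_outputs:
        beam_outputs[seq_id] = outstanding_beams.pop()  # 11 forks from 10

print([beam_outputs[s] for s in seq_ids])  # [(10, 7), (10, 9)]
```
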
def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
) -> Dict[int, SequenceOutputs]:
    seq_outputs: Dict[int, SequenceOutputs] = {}

    # TODO(woosuk): Optimize.
    idx = 0
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        if i < input_metadata.num_prompts:
            # Generate the next tokens for a prompt input.
            assert len(seq_ids) == sampling_params.n
            prob = probs[idx]
            logprob = logprobs[idx]
            idx += 1

            # Sample the next tokens.
            next_token_ids = _sample_from_prompt(prob, sampling_params)
            # Get top-k log probabilities for the next tokens.
            next_logprobs = _get_topk_logprobs(
                logprob, sampling_params.num_logprobs)

            # Build the output.
            for seq_id, next_token_id in zip(seq_ids, next_token_ids):
                output_logprobs = next_logprobs.copy()
                output_logprobs[next_token_id] = logprob[next_token_id].item()
                seq_outputs[seq_id] = SequenceOutputs(
                    seq_id, seq_id, next_token_id, output_logprobs)
        else:
            # Generate the next tokens for generation tokens.
            prob = probs[idx:idx + len(seq_ids)]
            logprob = logprobs[idx:idx + len(seq_ids)]
            idx += len(seq_ids)

            # Sample the next tokens.
            seq_logprobs = [
                input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
            parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                seq_ids, prob, logprob, seq_logprobs, sampling_params)

            # Get top-k log probabilities for the next tokens.
            next_logprobs: Dict[int, Dict[int, float]] = {}
            for i, seq_id in enumerate(seq_ids):
                next_logprobs[seq_id] = _get_topk_logprobs(
                    logprob[i], sampling_params.num_logprobs)

            # Build the output.
            for seq_id, parent_seq_id, next_token_id in zip(
                    seq_ids, parent_seq_ids, next_token_ids):
                i = seq_ids.index(parent_seq_id)
                output_logprobs = next_logprobs[parent_seq_id].copy()
                output_logprobs[next_token_id] = logprob[i, next_token_id].item()
                seq_outputs[seq_id] = SequenceOutputs(
                    seq_id,
                    parent_seq_id,
                    next_token_id,
                    output_logprobs,
                )

    return seq_outputs

@@ -1,24 +0,0 @@
from typing import Union

import torch

_STR_DTYPE_TO_TORCH_DTYPE = {
    'half': torch.half,
    'float': torch.float,
    'float16': torch.float16,
    'float32': torch.float32,
}


def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
    if isinstance(dtype, str):
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
    else:
        torch_dtype = dtype
    return torch_dtype


def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
    torch_dtype = get_torch_dtype(dtype)
    return torch.tensor([], dtype=torch_dtype).element_size()

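A quick usage check of the two helpers above; the printed values reflect standard PyTorch dtype sizes.

```python
import torch

print(get_torch_dtype('half'))        # torch.float16
print(get_dtype_size('float32'))      # 4 (bytes per element)
print(get_dtype_size(torch.float16))  # 2
```
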
@@ -1,12 +0,0 @@
import cacheflow.parallel_utils.parallel_state
import cacheflow.parallel_utils.tensor_parallel
import cacheflow.parallel_utils.utils

# Alias parallel_state as mpu, its legacy name
mpu = parallel_state

__all__ = [
    "parallel_state",
    "tensor_parallel",
    "utils",
]

@@ -1,593 +0,0 @@
|
|||||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
|
||||||
|
|
||||||
"""Model and data parallel groups."""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from .utils import GlobalMemoryBuffer
|
|
||||||
|
|
||||||
# Intra-layer model parallel group that the current rank belongs to.
|
|
||||||
_TENSOR_MODEL_PARALLEL_GROUP = None
|
|
||||||
# Inter-layer model parallel group that the current rank belongs to.
|
|
||||||
_PIPELINE_MODEL_PARALLEL_GROUP = None
|
|
||||||
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
|
|
||||||
_MODEL_PARALLEL_GROUP = None
|
|
||||||
# Embedding group.
|
|
||||||
_EMBEDDING_GROUP = None
|
|
||||||
# Position embedding group.
|
|
||||||
_POSITION_EMBEDDING_GROUP = None
|
|
||||||
# Data parallel group that the current rank belongs to.
|
|
||||||
_DATA_PARALLEL_GROUP = None
|
|
||||||
|
|
||||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
|
|
||||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
|
||||||
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
|
|
||||||
|
|
||||||
# These values enable us to change the mpu sizes on the fly.
|
|
||||||
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
|
|
||||||
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
|
||||||
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
|
|
||||||
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
|
|
||||||
|
|
||||||
# A list of ranks that have a copy of the embedding.
|
|
||||||
_EMBEDDING_GLOBAL_RANKS = None
|
|
||||||
|
|
||||||
# A list of ranks that have a copy of the position embedding.
|
|
||||||
_POSITION_EMBEDDING_GLOBAL_RANKS = None
|
|
||||||
|
|
||||||
# A list of global ranks for each pipeline group to ease calculation of the source
|
|
||||||
# rank when broadcasting from the first or last pipeline stage.
|
|
||||||
_PIPELINE_GLOBAL_RANKS = None
|
|
||||||
|
|
||||||
# A list of global ranks for each data parallel group to ease calculation of the source
|
|
||||||
# rank when broadcasting weights from src to all other data parallel ranks
|
|
||||||
_DATA_PARALLEL_GLOBAL_RANKS = None
|
|
||||||
|
|
||||||
# Memory buffers to avoid dynamic memory allocation
|
|
||||||
_GLOBAL_MEMORY_BUFFER = None
|
|
||||||
|
|
||||||
_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None
|
|
||||||
|
|
||||||
def initialize_model_parallel(
|
|
||||||
tensor_model_parallel_size: int = 1,
|
|
||||||
pipeline_model_parallel_size: int = 1,
|
|
||||||
virtual_pipeline_model_parallel_size: Optional[int] = None,
|
|
||||||
pipeline_model_parallel_split_rank: Optional[int] = None,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Initialize model and data parallel groups.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
|
|
||||||
pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
|
|
||||||
virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
|
|
||||||
pipeline).
|
|
||||||
pipeline_model_parallel_split_rank: for models with both encoder and decoder,
|
|
||||||
rank in pipeline with split point.
|
|
||||||
|
|
||||||
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
|
|
||||||
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
|
|
||||||
the model pipeline. The present function will
|
|
||||||
create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
|
|
||||||
and 8 data-parallel groups as:
|
|
||||||
8 data_parallel groups:
|
|
||||||
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
|
|
||||||
8 tensor model-parallel groups:
|
|
||||||
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
|
|
||||||
4 pipeline model-parallel groups:
|
|
||||||
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
|
|
||||||
Note that for efficiency, the caller should make sure adjacent ranks
|
|
||||||
are on the same DGX box. For example if we are using 2 DGX-1 boxes
|
|
||||||
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
|
|
||||||
ranks 8 to 15 belong to the second box.
|
|
||||||
"""
|
|
||||||
# Get world size and rank. Ensure some consistencies.
|
|
||||||
assert torch.distributed.is_initialized()
|
|
||||||
world_size: int = torch.distributed.get_world_size()
|
|
||||||
|
|
||||||
if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
|
|
||||||
f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
|
|
||||||
)
|
|
||||||
|
|
||||||
data_parallel_size: int = world_size // (tensor_model_parallel_size *
|
|
||||||
pipeline_model_parallel_size)
|
|
||||||
|
|
||||||
num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
|
|
||||||
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
|
|
||||||
num_data_parallel_groups: int = world_size // data_parallel_size
|
|
||||||
|
|
||||||
if virtual_pipeline_model_parallel_size is not None:
|
|
||||||
if not pipeline_model_parallel_size > 2:
|
|
||||||
raise RuntimeError("pipeline-model-parallel size should be greater than 2 with "
|
|
||||||
"interleaved schedule")
|
|
||||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
|
||||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
|
||||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
|
|
||||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size
|
|
||||||
|
|
||||||
if pipeline_model_parallel_split_rank is not None:
|
|
||||||
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
|
|
||||||
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
|
|
||||||
|
|
||||||
rank = torch.distributed.get_rank()
|
|
||||||
|
|
||||||
# Build the data-parallel groups.
|
|
||||||
global _DATA_PARALLEL_GROUP
|
|
||||||
global _DATA_PARALLEL_GLOBAL_RANKS
|
|
||||||
assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
|
|
||||||
all_data_parallel_group_ranks = []
|
|
||||||
for i in range(pipeline_model_parallel_size):
|
|
||||||
start_rank = i * num_pipeline_model_parallel_groups
|
|
||||||
end_rank = (i + 1) * num_pipeline_model_parallel_groups
|
|
||||||
for j in range(tensor_model_parallel_size):
|
|
||||||
ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
|
|
||||||
all_data_parallel_group_ranks.append(list(ranks))
|
|
||||||
group = torch.distributed.new_group(ranks)
|
|
||||||
if rank in ranks:
|
|
||||||
_DATA_PARALLEL_GROUP = group
|
|
||||||
_DATA_PARALLEL_GLOBAL_RANKS = ranks
|
|
||||||
|
|
||||||
# Build the model-parallel groups.
|
|
||||||
global _MODEL_PARALLEL_GROUP
|
|
||||||
assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
|
|
||||||
for i in range(data_parallel_size):
|
|
||||||
ranks = [data_parallel_group_ranks[i]
|
|
||||||
for data_parallel_group_ranks in all_data_parallel_group_ranks]
|
|
||||||
group = torch.distributed.new_group(ranks)
|
|
||||||
if rank in ranks:
|
|
||||||
_MODEL_PARALLEL_GROUP = group
|
|
||||||
|
|
||||||
# Build the tensor model-parallel groups.
|
|
||||||
global _TENSOR_MODEL_PARALLEL_GROUP
|
|
||||||
assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
|
|
||||||
'tensor model parallel group is already initialized'
|
|
||||||
for i in range(num_tensor_model_parallel_groups):
|
|
||||||
ranks = range(i * tensor_model_parallel_size,
|
|
||||||
(i + 1) * tensor_model_parallel_size)
|
|
||||||
group = torch.distributed.new_group(ranks)
|
|
||||||
if rank in ranks:
|
|
||||||
_TENSOR_MODEL_PARALLEL_GROUP = group
|
|
||||||
|
|
||||||
# Build the pipeline model-parallel groups and embedding groups
|
|
||||||
# (first and last rank in each pipeline model-parallel group).
|
|
||||||
global _PIPELINE_MODEL_PARALLEL_GROUP
|
|
||||||
global _PIPELINE_GLOBAL_RANKS
|
|
||||||
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
|
|
||||||
'pipeline model parallel group is already initialized'
|
|
||||||
global _EMBEDDING_GROUP
|
|
||||||
global _EMBEDDING_GLOBAL_RANKS
|
|
||||||
assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
|
|
||||||
global _POSITION_EMBEDDING_GROUP
|
|
||||||
global _POSITION_EMBEDDING_GLOBAL_RANKS
|
|
||||||
assert _POSITION_EMBEDDING_GROUP is None, \
|
|
||||||
'position embedding group is already initialized'
|
|
||||||
for i in range(num_pipeline_model_parallel_groups):
|
|
||||||
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
|
|
||||||
group = torch.distributed.new_group(ranks)
|
|
||||||
if rank in ranks:
|
|
||||||
_PIPELINE_MODEL_PARALLEL_GROUP = group
|
|
||||||
_PIPELINE_GLOBAL_RANKS = ranks
|
|
||||||
# Setup embedding group (to exchange gradients between
|
|
||||||
# first and last stages).
|
|
||||||
if len(ranks) > 1:
|
|
||||||
embedding_ranks = [ranks[0], ranks[-1]]
|
|
||||||
position_embedding_ranks = [ranks[0]]
|
|
||||||
if pipeline_model_parallel_split_rank is not None:
|
|
||||||
if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
|
|
||||||
embedding_ranks = [ranks[0],
|
|
||||||
ranks[pipeline_model_parallel_split_rank],
|
|
||||||
ranks[-1]]
|
|
||||||
if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
|
|
||||||
position_embedding_ranks = [ranks[0],
|
|
||||||
ranks[pipeline_model_parallel_split_rank]]
|
|
||||||
else:
|
|
||||||
embedding_ranks = ranks
|
|
||||||
position_embedding_ranks = ranks
|
|
||||||
|
|
||||||
group = torch.distributed.new_group(embedding_ranks)
|
|
||||||
if rank in embedding_ranks:
|
|
||||||
_EMBEDDING_GROUP = group
|
|
||||||
if rank in ranks:
|
|
||||||
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
|
|
||||||
|
|
||||||
group = torch.distributed.new_group(position_embedding_ranks)
|
|
||||||
if rank in position_embedding_ranks:
|
|
||||||
_POSITION_EMBEDDING_GROUP = group
|
|
||||||
if rank in ranks:
|
|
||||||
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
|
|
||||||
|
|
||||||
# Initialize global memory buffer
|
|
||||||
# This isn't really "parallel state" but there isn't another good place to
|
|
||||||
# put this. If we end up with a more generic initialization of megatron-core
|
|
||||||
# we could stick it there
|
|
||||||
_set_global_memory_buffer()
|
|
||||||
|
|
||||||
|
|
||||||
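To connect the docstring example above with the loops that build the groups, here is a pure-Python rerun of the rank arithmetic for world_size 16, tensor parallel 2, pipeline parallel 4; it needs no torch.distributed and only mirrors the group construction.

```python
world_size, tp, pp = 16, 2, 4
num_pp_groups = world_size // pp  # 4

# Data-parallel groups (same loop structure as in initialize_model_parallel).
data_parallel = []
for i in range(pp):
    start, end = i * num_pp_groups, (i + 1) * num_pp_groups
    for j in range(tp):
        data_parallel.append(list(range(start + j, end, tp)))
print(data_parallel)
# [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], [12, 14], [13, 15]]

# Tensor-parallel groups: consecutive blocks of size tp.
print([list(range(i * tp, (i + 1) * tp)) for i in range(world_size // tp)])
# Pipeline-parallel groups: stride num_pp_groups across the world.
print([list(range(i, world_size, num_pp_groups)) for i in range(num_pp_groups)])
```
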
def initialize_all_reduce_launcher(
|
|
||||||
max_num_tokens: int,
|
|
||||||
hidden_size: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
disable_graph: bool = False,
|
|
||||||
) -> None:
|
|
||||||
global _ALL_REDUCE_LAUNCHER
|
|
||||||
_ALL_REDUCE_LAUNCHER = GraphAllReduce(
|
|
||||||
max_num_tokens=max_num_tokens,
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
dtype=dtype,
|
|
||||||
disable_graph=disable_graph,
|
|
||||||
)
|
|
||||||
|
|
||||||
def model_parallel_is_initialized():
|
|
||||||
"""Check if model and data parallel groups are initialized."""
|
|
||||||
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
|
|
||||||
_PIPELINE_MODEL_PARALLEL_GROUP is None or \
|
|
||||||
_DATA_PARALLEL_GROUP is None:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_parallel_group():
|
|
||||||
"""Get the model parallel group the caller rank belongs to."""
|
|
||||||
assert _MODEL_PARALLEL_GROUP is not None, \
|
|
||||||
'model parallel group is not initialized'
|
|
||||||
return _MODEL_PARALLEL_GROUP
|
|
||||||
|
|
||||||
|
|
||||||
def get_tensor_model_parallel_group():
|
|
||||||
"""Get the tensor model parallel group the caller rank belongs to."""
|
|
||||||
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
|
|
||||||
'intra_layer_model parallel group is not initialized'
|
|
||||||
return _TENSOR_MODEL_PARALLEL_GROUP
|
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_model_parallel_group():
|
|
||||||
"""Get the pipeline model parallel group the caller rank belongs to."""
|
|
||||||
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
|
|
||||||
'pipeline_model parallel group is not initialized'
|
|
||||||
return _PIPELINE_MODEL_PARALLEL_GROUP
|
|
||||||
|
|
||||||
|
|
||||||
def get_data_parallel_group():
|
|
||||||
"""Get the data parallel group the caller rank belongs to."""
|
|
||||||
assert _DATA_PARALLEL_GROUP is not None, \
|
|
||||||
'data parallel group is not initialized'
|
|
||||||
return _DATA_PARALLEL_GROUP
|
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_group():
|
|
||||||
"""Get the embedding group the caller rank belongs to."""
|
|
||||||
assert _EMBEDDING_GROUP is not None, \
|
|
||||||
'embedding group is not initialized'
|
|
||||||
return _EMBEDDING_GROUP
|
|
||||||
|
|
||||||
|
|
||||||
def get_position_embedding_group():
|
|
||||||
"""Get the position embedding group the caller rank belongs to."""
|
|
||||||
assert _POSITION_EMBEDDING_GROUP is not None, \
|
|
||||||
'position embedding group is not initialized'
|
|
||||||
return _POSITION_EMBEDDING_GROUP
|
|
||||||
|
|
||||||
|
|
||||||
def set_tensor_model_parallel_world_size(world_size):
|
|
||||||
"""Set the tensor model parallel size"""
|
|
||||||
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
|
||||||
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
|
|
||||||
|
|
||||||
|
|
||||||
def set_pipeline_model_parallel_world_size(world_size):
|
|
||||||
"""Set the pipeline model parallel size"""
|
|
||||||
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
|
||||||
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
|
|
||||||
|
|
||||||
|
|
||||||
def get_tensor_model_parallel_world_size():
|
|
||||||
"""Return world size for the tensor model parallel group."""
|
|
||||||
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
|
||||||
if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
|
|
||||||
return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
|
||||||
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
|
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_model_parallel_world_size():
|
|
||||||
"""Return world size for the pipeline model parallel group."""
|
|
||||||
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
|
||||||
if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
|
|
||||||
return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
|
||||||
return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
|
|
||||||
|
|
||||||
|
|
||||||
def set_tensor_model_parallel_rank(rank):
|
|
||||||
"""Set tensor model parallel rank."""
|
|
||||||
global _MPU_TENSOR_MODEL_PARALLEL_RANK
|
|
||||||
_MPU_TENSOR_MODEL_PARALLEL_RANK = rank
|
|
||||||
|
|
||||||
|
|
||||||
def set_pipeline_model_parallel_rank(rank):
|
|
||||||
"""Set pipeline model parallel rank."""
|
|
||||||
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
|
||||||
_MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
|
|
||||||
|
|
||||||
|
|
||||||
def set_pipeline_model_parallel_split_rank(rank):
|
|
||||||
"""Set pipeline model parallel split rank."""
|
|
||||||
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
|
|
||||||
|
|
||||||
|
|
||||||
def get_tensor_model_parallel_rank():
|
|
||||||
"""Return my rank for the tensor model parallel group."""
|
|
||||||
global _MPU_TENSOR_MODEL_PARALLEL_RANK
|
|
||||||
if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
|
|
||||||
return _MPU_TENSOR_MODEL_PARALLEL_RANK
|
|
||||||
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
|
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_model_parallel_rank():
|
|
||||||
"""Return my rank for the pipeline model parallel group."""
|
|
||||||
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
|
||||||
if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
|
|
||||||
return _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
|
||||||
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def is_pipeline_first_stage(ignore_virtual=False):
|
|
||||||
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
|
|
||||||
if not ignore_virtual:
|
|
||||||
if get_virtual_pipeline_model_parallel_world_size() is not None and \
|
|
||||||
get_virtual_pipeline_model_parallel_rank() != 0:
|
|
||||||
return False
|
|
||||||
return get_pipeline_model_parallel_rank() == 0
|
|
||||||
|
|
||||||
|
|
||||||
def is_pipeline_last_stage(ignore_virtual=False):
|
|
||||||
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
|
|
||||||
if not ignore_virtual:
|
|
||||||
virtual_pipeline_model_parallel_world_size = \
|
|
||||||
get_virtual_pipeline_model_parallel_world_size()
|
|
||||||
if virtual_pipeline_model_parallel_world_size is not None and \
|
|
||||||
get_virtual_pipeline_model_parallel_rank() != (
|
|
||||||
virtual_pipeline_model_parallel_world_size - 1):
|
|
||||||
return False
|
|
||||||
return get_pipeline_model_parallel_rank() == (
|
|
||||||
get_pipeline_model_parallel_world_size() - 1)
|
|
||||||
|
|
||||||
|
|
||||||
def is_rank_in_embedding_group(ignore_virtual=False):
|
|
||||||
"""Return true if current rank is in embedding group, False otherwise."""
|
|
||||||
rank = torch.distributed.get_rank()
|
|
||||||
global _EMBEDDING_GLOBAL_RANKS
|
|
||||||
if ignore_virtual:
|
|
||||||
return rank in _EMBEDDING_GLOBAL_RANKS
|
|
||||||
if rank in _EMBEDDING_GLOBAL_RANKS:
|
|
||||||
if rank == _EMBEDDING_GLOBAL_RANKS[0]:
|
|
||||||
return is_pipeline_first_stage(ignore_virtual=False)
|
|
||||||
elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
|
|
||||||
return is_pipeline_last_stage(ignore_virtual=False)
|
|
||||||
else:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def is_rank_in_position_embedding_group():
|
|
||||||
"""Return true if current rank is in position embedding group, False otherwise."""
|
|
||||||
rank = torch.distributed.get_rank()
|
|
||||||
global _POSITION_EMBEDDING_GLOBAL_RANKS
|
|
||||||
return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
|
|
||||||
|
|
||||||
|
|
||||||
def is_pipeline_stage_before_split(rank=None):
|
|
||||||
"""Return True if pipeline stage executes encoder block for a model
|
|
||||||
with both encoder and decoder."""
|
|
||||||
if get_pipeline_model_parallel_world_size() == 1:
|
|
||||||
return True
|
|
||||||
if rank is None:
|
|
||||||
rank = get_pipeline_model_parallel_rank()
|
|
||||||
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
|
|
||||||
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
|
|
||||||
return True
|
|
||||||
if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def is_pipeline_stage_after_split(rank=None):
|
|
||||||
"""Return True if pipeline stage executes decoder block for a model
|
|
||||||
with both encoder and decoder."""
|
|
||||||
if get_pipeline_model_parallel_world_size() == 1:
|
|
||||||
return True
|
|
||||||
if rank is None:
|
|
||||||
rank = get_pipeline_model_parallel_rank()
|
|
||||||
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
|
|
||||||
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
|
|
||||||
return True
|
|
||||||
if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def is_pipeline_stage_at_split():
|
|
||||||
"""Return true if pipeline stage executes decoder block and next
|
|
||||||
stage executes encoder block for a model with both encoder and
|
|
||||||
decoder."""
|
|
||||||
rank = get_pipeline_model_parallel_rank()
|
|
||||||
return is_pipeline_stage_before_split(rank) and \
|
|
||||||
is_pipeline_stage_after_split(rank+1)
|
|
||||||
|
|
||||||
|
|
||||||
def get_virtual_pipeline_model_parallel_rank():
|
|
||||||
"""Return the virtual pipeline-parallel rank."""
|
|
||||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
|
||||||
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
|
||||||
|
|
||||||
|
|
||||||
def set_virtual_pipeline_model_parallel_rank(rank):
|
|
||||||
"""Set the virtual pipeline-parallel rank."""
|
|
||||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
|
||||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
|
|
||||||
|
|
||||||
|
|
||||||
def get_virtual_pipeline_model_parallel_world_size():
|
|
||||||
"""Return the virtual pipeline-parallel world size."""
|
|
||||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
|
||||||
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
|
||||||
|
|
||||||
|
|
||||||
def get_tensor_model_parallel_src_rank():
|
|
||||||
"""Calculate the global rank corresponding to the first local rank
|
|
||||||
in the tensor model parallel group."""
|
|
||||||
global_rank = torch.distributed.get_rank()
|
|
||||||
local_world_size = get_tensor_model_parallel_world_size()
|
|
||||||
return (global_rank // local_world_size) * local_world_size
|
|
||||||
|
|
||||||
|
|
||||||
def get_data_parallel_src_rank():
|
|
||||||
"""Calculate the global rank corresponding to the first local rank
|
|
||||||
in the data parallel group."""
|
|
||||||
assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \
|
|
||||||
"Data parallel group is not initialized"
|
|
||||||
return _DATA_PARALLEL_GLOBAL_RANKS[0]
|
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_model_parallel_first_rank():
|
|
||||||
"""Return the global rank of the first process in the pipeline for the
|
|
||||||
current tensor parallel group"""
|
|
||||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
|
||||||
"Pipeline parallel group is not initialized"
|
|
||||||
return _PIPELINE_GLOBAL_RANKS[0]
|
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_model_parallel_last_rank():
|
|
||||||
"""Return the global rank of the last process in the pipeline for the
|
|
||||||
current tensor parallel group"""
|
|
||||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
|
||||||
"Pipeline parallel group is not initialized"
|
|
||||||
last_rank_local = get_pipeline_model_parallel_world_size() - 1
|
|
||||||
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
|
|
||||||
|
|
||||||
def get_pipeline_model_parallel_next_rank():
|
|
||||||
"""Return the global rank that follows the caller in the pipeline"""
|
|
||||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
|
||||||
"Pipeline parallel group is not initialized"
|
|
||||||
rank_in_pipeline = get_pipeline_model_parallel_rank()
|
|
||||||
world_size = get_pipeline_model_parallel_world_size()
|
|
||||||
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
|
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_model_parallel_prev_rank():
|
|
||||||
"""Return the global rank that preceeds the caller in the pipeline"""
|
|
||||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
|
||||||
"Pipeline parallel group is not initialized"
|
|
||||||
rank_in_pipeline = get_pipeline_model_parallel_rank()
|
|
||||||
world_size = get_pipeline_model_parallel_world_size()
|
|
||||||
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
|
|
||||||
|
|
||||||
|
|
||||||
def get_data_parallel_world_size():
|
|
||||||
"""Return world size for the data parallel group."""
|
|
||||||
return torch.distributed.get_world_size(group=get_data_parallel_group())
|
|
||||||
|
|
||||||
|
|
||||||
def get_data_parallel_rank():
|
|
||||||
"""Return my rank for the data parallel group."""
|
|
||||||
return torch.distributed.get_rank(group=get_data_parallel_group())
|
|
||||||
|
|
||||||
def _set_global_memory_buffer():
|
|
||||||
"""Initialize global buffer"""
|
|
||||||
global _GLOBAL_MEMORY_BUFFER
|
|
||||||
assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
|
|
||||||
_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
|
|
||||||
|
|
||||||
def get_global_memory_buffer():
|
|
||||||
"""Return the global GlobalMemoryBuffer object"""
|
|
||||||
assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
|
|
||||||
return _GLOBAL_MEMORY_BUFFER
|
|
||||||
|
|
||||||
def get_all_reduce_launcher() -> 'GraphAllReduce':
|
|
||||||
assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
|
|
||||||
return _ALL_REDUCE_LAUNCHER
|
|
||||||
|
|
||||||
def destroy_model_parallel():
|
|
||||||
"""Set the groups to none."""
|
|
||||||
global _MODEL_PARALLEL_GROUP
|
|
||||||
_MODEL_PARALLEL_GROUP = None
|
|
||||||
global _TENSOR_MODEL_PARALLEL_GROUP
|
|
||||||
_TENSOR_MODEL_PARALLEL_GROUP = None
|
|
||||||
global _PIPELINE_MODEL_PARALLEL_GROUP
|
|
||||||
_PIPELINE_MODEL_PARALLEL_GROUP = None
|
|
||||||
global _DATA_PARALLEL_GROUP
|
|
||||||
_DATA_PARALLEL_GROUP = None
|
|
||||||
global _EMBEDDING_GROUP
|
|
||||||
_EMBEDDING_GROUP = None
|
|
||||||
global _POSITION_EMBEDDING_GROUP
|
|
||||||
_POSITION_EMBEDDING_GROUP = None
|
|
||||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
|
||||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
|
|
||||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
|
||||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
|
||||||
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
|
||||||
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
|
|
||||||
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
|
||||||
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
|
||||||
global _MPU_TENSOR_MODEL_PARALLEL_RANK
|
|
||||||
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
|
|
||||||
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
|
||||||
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
|
|
||||||
global _GLOBAL_MEMORY_BUFFER
|
|
||||||
_GLOBAL_MEMORY_BUFFER = None
|
|
||||||
|
|
||||||
|
|
||||||
class GraphAllReduce:
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
max_num_tokens: int,
|
|
||||||
hidden_size: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
disable_graph: bool = False,
|
|
||||||
) -> None:
|
|
||||||
self.max_num_tokens = max_num_tokens
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.disable_graph = disable_graph
|
|
||||||
|
|
||||||
tp_world_size = get_tensor_model_parallel_world_size()
|
|
||||||
if tp_world_size == 1:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.group = get_tensor_model_parallel_group()
|
|
||||||
self.buffer = torch.empty(
|
|
||||||
size=(max_num_tokens, hidden_size),
|
|
||||||
dtype=dtype,
|
|
||||||
device='cuda',
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build graphs for different number of tokens.
|
|
||||||
if not self.disable_graph:
|
|
||||||
self.graphs = {}
|
|
||||||
for num_tokens in range(8, max_num_tokens + 1, 8):
|
|
||||||
self.graphs[num_tokens] = self._build_graph(num_tokens)
|
|
||||||
|
|
||||||
def _build_graph(self, num_tokens: int) -> torch.cuda.CUDAGraph:
|
|
||||||
# Warm up.
|
|
||||||
torch.distributed.all_reduce(self.buffer[:num_tokens], group=self.group)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
# Build graph.
|
|
||||||
graph = torch.cuda.CUDAGraph()
|
|
||||||
with torch.cuda.graph(graph):
|
|
||||||
torch.distributed.all_reduce(
|
|
||||||
self.buffer[:num_tokens], group=self.group)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
return graph
|
|
||||||
|
|
||||||
def launch(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
# NOTE: x must be a slice of self.buffer.
|
|
||||||
num_tokens = x.shape[0]
|
|
||||||
if self.disable_graph:
|
|
||||||
torch.distributed.all_reduce(x, group=self.group)
|
|
||||||
else:
|
|
||||||
self.graphs[num_tokens].replay()
|
|
||||||
return x
|
|
||||||
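Note that `launch` indexes `self.graphs` directly, so (unless `disable_graph` is set) the token count must be one of the captured sizes, i.e. a multiple of 8 up to `max_num_tokens`. A caller would presumably pad to the next bucket; the helper below is hypothetical and only illustrates that rounding.

```python
def _padded_num_tokens(num_tokens: int, max_num_tokens: int) -> int:
    # Hypothetical helper: round up to the next multiple of 8 so that a
    # captured CUDA graph exists for this batch size.
    padded = (num_tokens + 7) // 8 * 8
    assert padded <= max_num_tokens
    return padded

print(_padded_num_tokens(13, 256))  # 16
```
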
@@ -1,55 +0,0 @@
|
|||||||
from .layers import (
|
|
||||||
ColumnParallelLinear,
|
|
||||||
RowParallelLinear,
|
|
||||||
VocabParallelEmbedding,
|
|
||||||
set_tensor_model_parallel_attributes,
|
|
||||||
set_defaults_if_not_set_tensor_model_parallel_attributes,
|
|
||||||
copy_tensor_model_parallel_attributes,
|
|
||||||
param_is_not_tensor_parallel_duplicate,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .mappings import (
|
|
||||||
copy_to_tensor_model_parallel_region,
|
|
||||||
gather_from_tensor_model_parallel_region,
|
|
||||||
gather_from_sequence_parallel_region,
|
|
||||||
scatter_to_tensor_model_parallel_region,
|
|
||||||
scatter_to_sequence_parallel_region,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .random import (
|
|
||||||
checkpoint,
|
|
||||||
get_cuda_rng_tracker,
|
|
||||||
model_parallel_cuda_manual_seed,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .utils import (
|
|
||||||
split_tensor_along_last_dim,
|
|
||||||
split_tensor_into_1d_equal_chunks,
|
|
||||||
gather_split_1d_tensor,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
#layers.py
|
|
||||||
"ColumnParallelLinear",
|
|
||||||
"RowParallelLinear",
|
|
||||||
"VocabParallelEmbedding",
|
|
||||||
"set_tensor_model_parallel_attributes",
|
|
||||||
"set_defaults_if_not_set_tensor_model_parallel_attributes",
|
|
||||||
"copy_tensor_model_parallel_attributes",
|
|
||||||
"param_is_not_tensor_parallel_duplicate",
|
|
||||||
# mappings.py
|
|
||||||
"copy_to_tensor_model_parallel_region",
|
|
||||||
"gather_from_tensor_model_parallel_region",
|
|
||||||
"gather_from_sequence_parallel_region",
|
|
||||||
# "reduce_from_tensor_model_parallel_region",
|
|
||||||
"scatter_to_tensor_model_parallel_region",
|
|
||||||
"scatter_to_sequence_parallel_region",
|
|
||||||
# random.py
|
|
||||||
"checkpoint",
|
|
||||||
"get_cuda_rng_tracker",
|
|
||||||
"model_parallel_cuda_manual_seed",
|
|
||||||
# utils.py
|
|
||||||
"split_tensor_along_last_dim",
|
|
||||||
"split_tensor_into_1d_equal_chunks",
|
|
||||||
"gather_split_1d_tensor",
|
|
||||||
]
|
|
||||||
@@ -1,446 +0,0 @@
|
|||||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
|
||||||
|
|
||||||
# Parts of the code here are adapted from PyTorch
|
|
||||||
# repo: https://github.com/pytorch/pytorch
|
|
||||||
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.nn.init as init
|
|
||||||
from torch.nn.parameter import Parameter
|
|
||||||
|
|
||||||
from cacheflow.parallel_utils.parallel_state import (
|
|
||||||
get_tensor_model_parallel_rank,
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
get_all_reduce_launcher,
|
|
||||||
)
|
|
||||||
from .mappings import (
|
|
||||||
copy_to_tensor_model_parallel_region,
|
|
||||||
gather_from_tensor_model_parallel_region,
|
|
||||||
reduce_from_tensor_model_parallel_region,
|
|
||||||
scatter_to_tensor_model_parallel_region,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .random import get_cuda_rng_tracker
|
|
||||||
from .utils import (
|
|
||||||
divide,
|
|
||||||
VocabUtility,
|
|
||||||
)
|
|
||||||
|
|
||||||
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
|
|
||||||
'partition_dim': -1,
|
|
||||||
'partition_stride': 1}
|
|
||||||
|
|
||||||
def param_is_not_tensor_parallel_duplicate(param):
|
|
||||||
return (hasattr(param, 'tensor_model_parallel') and
|
|
||||||
param.tensor_model_parallel) or (
|
|
||||||
get_tensor_model_parallel_rank() == 0)
|
|
||||||
|
|
||||||
|
|
||||||
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
|
|
||||||
# Make sure the attributes are not set.
|
|
||||||
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
|
|
||||||
assert not hasattr(tensor, attribute)
|
|
||||||
# Set the attributes.
|
|
||||||
setattr(tensor, 'tensor_model_parallel', is_parallel)
|
|
||||||
setattr(tensor, 'partition_dim', dim)
|
|
||||||
setattr(tensor, 'partition_stride', stride)
|
|
||||||
|
|
||||||
|
|
||||||
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
|
|
||||||
def maybe_set(attribute, value):
|
|
||||||
if not hasattr(tensor, attribute):
|
|
||||||
setattr(tensor, attribute, value)
|
|
||||||
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
|
|
||||||
maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
|
|
||||||
|
|
||||||
|
|
||||||
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
|
|
||||||
def maybe_copy(attribute):
|
|
||||||
if hasattr(source_tensor, attribute):
|
|
||||||
setattr(destination_tensor, attribute,
|
|
||||||
getattr(source_tensor, attribute))
|
|
||||||
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
|
|
||||||
maybe_copy(attribute)
|
|
||||||
|
|
||||||
|
|
||||||
def _initialize_affine_weight_gpu(weight, init_method,
|
|
||||||
partition_dim, stride=1):
|
|
||||||
"""Initialize affine weight for model parallel on GPU."""
|
|
||||||
|
|
||||||
set_tensor_model_parallel_attributes(tensor=weight,
|
|
||||||
is_parallel=True,
|
|
||||||
dim=partition_dim,
|
|
||||||
stride=stride)
|
|
||||||
|
|
||||||
with get_cuda_rng_tracker().fork():
|
|
||||||
init_method(weight)
|
|
||||||
|
|
||||||
|
|
||||||
def _initialize_affine_weight_cpu(weight, output_size, input_size,
|
|
||||||
per_partition_size, partition_dim,
|
|
||||||
init_method, stride=1,
|
|
||||||
return_master_weight=False,
|
|
||||||
*, params_dtype=None):
|
|
||||||
"""Initialize affine weight for model parallel.
|
|
||||||
|
|
||||||
Build the master weight on all processes and scatter
|
|
||||||
the relevant chunk."""
|
|
||||||
|
|
||||||
set_tensor_model_parallel_attributes(tensor=weight,
|
|
||||||
is_parallel=True,
|
|
||||||
dim=partition_dim,
|
|
||||||
stride=stride)
|
|
||||||
|
|
||||||
if params_dtype is None:
|
|
||||||
params_dtype = torch.get_default_dtype()
|
|
||||||
|
|
||||||
# Initialize master weight
|
|
||||||
master_weight = torch.empty(output_size, input_size,
|
|
||||||
dtype=torch.float,
|
|
||||||
requires_grad=False)
|
|
||||||
init_method(master_weight)
|
|
||||||
master_weight = master_weight.to(dtype=params_dtype)
|
|
||||||
|
|
||||||
# Split and copy
|
|
||||||
per_partition_per_stride_size = divide(per_partition_size, stride)
|
|
||||||
weight_list = torch.split(master_weight, per_partition_per_stride_size,
|
|
||||||
dim=partition_dim)
|
|
||||||
rank = get_tensor_model_parallel_rank()
|
|
||||||
world_size = get_tensor_model_parallel_world_size()
|
|
||||||
my_weight_list = weight_list[rank::world_size]
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
torch.cat(my_weight_list, dim=partition_dim, out=weight)
|
|
||||||
if return_master_weight:
|
|
||||||
return master_weight
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class VocabParallelEmbedding(torch.nn.Module):
|
|
||||||
"""Embedding parallelized in the vocabulary dimension.
|
|
||||||
|
|
||||||
This is mainly adapted from torch.nn.Embedding and all the default
|
|
||||||
values are kept.
|
|
||||||
Arguments:
|
|
||||||
num_embeddings: vocabulary size.
|
|
||||||
embedding_dim: size of hidden state.
|
|
||||||
|
|
||||||
Keyword Arguments:
|
|
||||||
init_method: method to initialize weights.
|
|
||||||
params_dtype
|
|
||||||
use_cpu_initialization
|
|
||||||
perform_initialization
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, num_embeddings: int, embedding_dim: int, *,
|
|
||||||
init_method=init.xavier_normal_,
|
|
||||||
params_dtype: torch.dtype=None,
|
|
||||||
use_cpu_initialization: bool=False,
|
|
||||||
perform_initialization: bool=True):
|
|
||||||
super(VocabParallelEmbedding, self).__init__()
|
|
||||||
# Keep the input dimensions.
|
|
||||||
self.num_embeddings = num_embeddings
|
|
||||||
self.embedding_dim = embedding_dim
|
|
||||||
if params_dtype is None:
|
|
||||||
params_dtype = torch.get_default_dtype()
|
|
||||||
|
|
||||||
# Set the defaults for compatibility.
|
|
||||||
self.padding_idx = None
|
|
||||||
self.max_norm = None
|
|
||||||
self.norm_type = 2.
|
|
||||||
self.scale_grad_by_freq = False
|
|
||||||
self.sparse = False
|
|
||||||
self._weight = None
|
|
||||||
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
|
|
||||||
# Divide the weight matrix along the vocabulary dimension.
|
|
||||||
self.vocab_start_index, self.vocab_end_index = \
|
|
||||||
VocabUtility.vocab_range_from_global_vocab_size(
|
|
||||||
self.num_embeddings, get_tensor_model_parallel_rank(),
|
|
||||||
self.tensor_model_parallel_size)
|
|
||||||
self.num_embeddings_per_partition = self.vocab_end_index - \
|
|
||||||
self.vocab_start_index
|
|
||||||
|
|
||||||
# Allocate weights and initialize.
|
|
||||||
if use_cpu_initialization:
|
|
||||||
self.weight = Parameter(torch.empty(
|
|
||||||
self.num_embeddings_per_partition, self.embedding_dim,
|
|
||||||
dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
_initialize_affine_weight_cpu(
|
|
||||||
self.weight, self.num_embeddings, self.embedding_dim,
|
|
||||||
self.num_embeddings_per_partition, 0, init_method,
|
|
||||||
params_dtype=params_dtype)
|
|
||||||
else:
|
|
||||||
self.weight = Parameter(torch.empty(
|
|
||||||
self.num_embeddings_per_partition, self.embedding_dim,
|
|
||||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
|
||||||
partition_dim=0, stride=1)
|
|
||||||
|
|
||||||
def forward(self, input_):
|
|
||||||
if self.tensor_model_parallel_size > 1:
|
|
||||||
# Build the mask.
|
|
||||||
input_mask = (input_ < self.vocab_start_index) | \
|
|
||||||
(input_ >= self.vocab_end_index)
|
|
||||||
# Mask the input.
|
|
||||||
masked_input = input_.clone() - self.vocab_start_index
|
|
||||||
masked_input[input_mask] = 0
|
|
||||||
else:
|
|
||||||
masked_input = input_
|
|
||||||
# Get the embeddings.
|
|
||||||
output_parallel = F.embedding(masked_input, self.weight,
|
|
||||||
self.padding_idx, self.max_norm,
|
|
||||||
self.norm_type, self.scale_grad_by_freq,
|
|
||||||
self.sparse)
|
|
||||||
# Mask the output embedding.
|
|
||||||
if self.tensor_model_parallel_size > 1:
|
|
||||||
output_parallel[input_mask, :] = 0.0
|
|
||||||
# Reduce across all the model parallel GPUs.
|
|
||||||
output = reduce_from_tensor_model_parallel_region(output_parallel)
|
|
||||||
return output
|
|
||||||
|
|
||||||
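To see what the masking in `forward` does, here is the same logic on a toy vocabulary of 10 tokens split across 2 ranks, using plain tensor ops and no distributed setup; the all-reduce at the end of the real forward would sum the per-rank partial embeddings.

```python
import torch

vocab_size, world_size, rank = 10, 2, 1
per_rank = vocab_size // world_size
vocab_start, vocab_end = rank * per_rank, (rank + 1) * per_rank  # [5, 10)

input_ = torch.tensor([1, 6, 9])
input_mask = (input_ < vocab_start) | (input_ >= vocab_end)
masked_input = input_.clone() - vocab_start
masked_input[input_mask] = 0
print(masked_input.tolist())  # [0, 1, 4] -- local rows; token 1 is masked out

weight = torch.randn(per_rank, 4)   # this rank's shard of the embedding table
out = weight[masked_input]          # stand-in for F.embedding
out[input_mask, :] = 0.0            # zero rows owned by other ranks
# torch.distributed.all_reduce(out) would then combine the shards.
```
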
|
|
||||||
class ColumnParallelLinear(torch.nn.Module):
|
|
||||||
"""Linear layer with column parallelism.
|
|
||||||
|
|
||||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
|
||||||
its second dimension as A = [A_1, ..., A_p].
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
input_size: first dimension of matrix A.
|
|
||||||
output_size: second dimension of matrix A.
|
|
||||||
|
|
||||||
Keyword Arguments
|
|
||||||
bias: If true, add bias
|
|
||||||
gather_output: If true, call all-gather on output and make Y available
|
|
||||||
to all GPUs, otherwise, every GPU will have its output
|
|
||||||
which is Y_i = XA_i
|
|
||||||
init_method: method to initialize weights. Note that bias is always set
|
|
||||||
to zero.
|
|
||||||
stride: For the strided linear layers.
|
|
||||||
keep_master_weight_for_test: This was added for testing and should be
|
|
||||||
set to False. It returns the master weights
|
|
||||||
used for initialization.
|
|
||||||
skip_bias_add: This was added to enable performance optimizations where bias
can be fused with other elementwise operations. We skip
adding bias but instead return it.
|
|
||||||
params_dtype:
|
|
||||||
use_cpu_initialization:
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, input_size, output_size, *,
|
|
||||||
bias=True, gather_output=True,
|
|
||||||
init_method=init.xavier_normal_, stride=1,
|
|
||||||
keep_master_weight_for_test=False,
|
|
||||||
skip_bias_add=False,
|
|
||||||
params_dtype=None,
|
|
||||||
use_cpu_initialization=False,
|
|
||||||
perform_initialization=True,
|
|
||||||
):
|
|
||||||
super(ColumnParallelLinear, self).__init__()
|
|
||||||
|
|
||||||
# Keep input parameters
|
|
||||||
self.input_size = input_size
|
|
||||||
self.output_size = output_size
|
|
||||||
self.gather_output = gather_output
|
|
||||||
# Divide the weight matrix along the last dimension.
|
|
||||||
world_size = get_tensor_model_parallel_world_size()
|
|
||||||
self.output_size_per_partition = divide(output_size, world_size)
|
|
||||||
self.skip_bias_add = skip_bias_add
|
|
||||||
|
|
||||||
if params_dtype is None:
|
|
||||||
params_dtype = torch.get_default_dtype()
|
|
||||||
|
|
||||||
# Parameters.
|
|
||||||
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
|
||||||
# we allocate the transpose.
|
|
||||||
# Initialize weight.
|
|
||||||
if use_cpu_initialization:
|
|
||||||
self.weight = Parameter(torch.empty(self.output_size_per_partition,
|
|
||||||
self.input_size,
|
|
||||||
dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
self.master_weight = _initialize_affine_weight_cpu(
|
|
||||||
self.weight, self.output_size, self.input_size,
|
|
||||||
self.output_size_per_partition, 0, init_method,
|
|
||||||
stride=stride, return_master_weight=keep_master_weight_for_test)
|
|
||||||
else:
|
|
||||||
self.weight = Parameter(torch.empty(
|
|
||||||
self.output_size_per_partition, self.input_size,
|
|
||||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
|
||||||
partition_dim=0, stride=stride)
|
|
||||||
|
|
||||||
if bias:
|
|
||||||
if use_cpu_initialization:
|
|
||||||
self.bias = Parameter(torch.empty(
|
|
||||||
self.output_size_per_partition, dtype=params_dtype))
|
|
||||||
else:
|
|
||||||
self.bias = Parameter(torch.empty(
|
|
||||||
self.output_size_per_partition,
|
|
||||||
device=torch.cuda.current_device(),
|
|
||||||
dtype=params_dtype))
|
|
||||||
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
|
|
||||||
# Always initialize bias to zero.
|
|
||||||
with torch.no_grad():
|
|
||||||
self.bias.zero_()
|
|
||||||
else:
|
|
||||||
self.register_parameter('bias', None)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, input_):
|
|
||||||
"""Forward of ColumnParallelLinear
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- output
|
|
||||||
- bias
|
|
||||||
"""
|
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
|
||||||
|
|
||||||
input_parallel = copy_to_tensor_model_parallel_region(input_)
|
|
||||||
# Matrix multiply.
|
|
||||||
output_parallel = F.linear(input_parallel, self.weight, bias)
|
|
||||||
if self.gather_output:
|
|
||||||
# All-gather across the partitions.
|
|
||||||
output = gather_from_tensor_model_parallel_region(output_parallel)
|
|
||||||
else:
|
|
||||||
output = output_parallel
|
|
||||||
output_bias = self.bias if self.skip_bias_add else None
|
|
||||||
return output, output_bias
|
|
||||||
|
|
||||||
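A single-process numerical check of the column-parallel decomposition the docstring describes: splitting A along its output dimension and concatenating the partial outputs reproduces the full X @ A. The shapes are illustrative only.

```python
import torch

X = torch.randn(3, 8)          # [tokens, hidden]
A = torch.randn(8, 6)          # full weight, Y = X @ A
A1, A2 = A.chunk(2, dim=1)     # each "rank" holds half of the output columns

Y_full = X @ A
Y_gathered = torch.cat([X @ A1, X @ A2], dim=1)  # what gather_output produces
print(torch.allclose(Y_full, Y_gathered, atol=1e-5))  # True
```
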
|
|
||||||
class RowParallelLinear(torch.nn.Module):
|
|
||||||
"""Linear layer with row parallelism.
|
|
||||||
|
|
||||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
|
||||||
its first dimension and X along its second dimension as:
|
|
||||||
- -
|
|
||||||
| A_1 |
|
|
||||||
| . |
|
|
||||||
A = | . | X = [X_1, ..., X_p]
|
|
||||||
| . |
|
|
||||||
| A_p |
|
|
||||||
- -
|
|
||||||
Arguments:
|
|
||||||
input_size: first dimension of matrix A.
|
|
||||||
output_size: second dimension of matrix A.
|
|
||||||
|
|
||||||
Keyword Arguments:
|
|
||||||
bias: If true, add bias. Note that bias is not parallelized.
|
|
||||||
input_is_parallel: If true, we assume that the input is already
|
|
||||||
split across the GPUs and we do not split
|
|
||||||
again.
|
|
||||||
init_method: method to initialize weights. Note that bias is always set
|
|
||||||
to zero.
|
|
||||||
stride: For the strided linear layers.
|
|
||||||
keep_master_weight_for_test: This was added for testing and should be
|
|
||||||
set to False. It returns the master weights
|
|
||||||
used for initialization.
|
|
||||||
skip_bias_add: This was added to enable performance optimization where bias
|
|
||||||
can be fused with other elementwise operations. We skip
|
|
||||||
adding bias but instead return it.
|
|
||||||
params_dtype:
|
|
||||||
use_cpu_initialization:
|
|
||||||
perform_initialization:
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, input_size, output_size, *,
|
|
||||||
bias=True, input_is_parallel=False,
|
|
||||||
init_method=init.xavier_normal_, stride=1,
|
|
||||||
keep_master_weight_for_test=False,
|
|
||||||
skip_bias_add=False,
|
|
||||||
params_dtype=None,
|
|
||||||
use_cpu_initialization=False,
|
|
||||||
perform_initialization=True,
|
|
||||||
):
|
|
||||||
super(RowParallelLinear, self).__init__()
|
|
||||||
|
|
||||||
# Keep input parameters
|
|
||||||
self.input_size = input_size
|
|
||||||
self.output_size = output_size
|
|
||||||
self.input_is_parallel = input_is_parallel
|
|
||||||
if params_dtype is None:
|
|
||||||
params_dtype = torch.get_default_dtype()
|
|
||||||
|
|
||||||
# Divide the weight matrix along the last dimension.
|
|
||||||
world_size = get_tensor_model_parallel_world_size()
|
|
||||||
self.input_size_per_partition = divide(input_size, world_size)
|
|
||||||
self.skip_bias_add = skip_bias_add
|
|
||||||
|
|
||||||
# Parameters.
|
|
||||||
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
|
||||||
# we allocate the transpose.
|
|
||||||
# Initialize weight.
|
|
||||||
if use_cpu_initialization:
|
|
||||||
self.weight = Parameter(torch.empty(self.output_size,
|
|
||||||
self.input_size_per_partition,
|
|
||||||
dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
self.master_weight = _initialize_affine_weight_cpu(
|
|
||||||
self.weight, self.output_size, self.input_size,
|
|
||||||
self.input_size_per_partition, 1, init_method,
|
|
||||||
stride=stride, return_master_weight=keep_master_weight_for_test,
|
|
||||||
params_dtype=params_dtype)
|
|
||||||
else:
|
|
||||||
self.weight = Parameter(torch.empty(
|
|
||||||
self.output_size, self.input_size_per_partition,
|
|
||||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
|
||||||
if perform_initialization:
|
|
||||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
|
||||||
partition_dim=1, stride=stride)
|
|
||||||
if bias:
|
|
||||||
if use_cpu_initialization:
|
|
||||||
self.bias = Parameter(torch.empty(self.output_size,
|
|
||||||
dtype=params_dtype))
|
|
||||||
else:
|
|
||||||
self.bias = Parameter(torch.empty(
|
|
||||||
self.output_size, device=torch.cuda.current_device(),
|
|
||||||
dtype=params_dtype))
|
|
||||||
|
|
||||||
# Always initialize bias to zero.
|
|
||||||
with torch.no_grad():
|
|
||||||
self.bias.zero_()
|
|
||||||
else:
|
|
||||||
self.register_parameter('bias', None)
|
|
||||||
self.weight_t = self.weight.t()
|
|
||||||
|
|
||||||
def forward(self, input_):
|
|
||||||
"""Forward of RowParallelLinear
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- output
|
|
||||||
- bias
|
|
||||||
"""
|
|
||||||
# Set up backprop all-reduce.
|
|
||||||
if self.input_is_parallel:
|
|
||||||
input_parallel = input_
|
|
||||||
else:
|
|
||||||
input_parallel = scatter_to_tensor_model_parallel_region(input_)
|
|
||||||
if get_tensor_model_parallel_world_size() == 1:
|
|
||||||
# Matrix multiply.
|
|
||||||
output_ = F.linear(input_parallel, self.weight)
|
|
||||||
else:
|
|
||||||
# Matrix multiply.
|
|
||||||
all_reduce_launcher = get_all_reduce_launcher()
|
|
||||||
num_tokens = input_parallel.shape[0]
|
|
||||||
output_buffer = all_reduce_launcher.buffer[:num_tokens]
|
|
||||||
torch.matmul(input_parallel, self.weight_t, out=output_buffer)
|
|
||||||
# All-reduce across all the partitions.
|
|
||||||
output_ = all_reduce_launcher.launch(output_buffer)
|
|
||||||
|
|
||||||
if not self.skip_bias_add:
|
|
||||||
output = output_ + self.bias if self.bias is not None else output_
|
|
||||||
output_bias = None
|
|
||||||
else:
|
|
||||||
output = output_
|
|
||||||
output_bias = self.bias
|
|
||||||
return output, output_bias
|
|
||||||
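And the matching check for row parallelism: splitting X along its last dimension and A along its first, then summing the partial products (which is what the all-reduce in `forward` computes across ranks), reproduces the full matmul.

```python
import torch

X = torch.randn(3, 8)
A = torch.randn(8, 6)
X1, X2 = X.chunk(2, dim=1)     # input split across "ranks"
A1, A2 = A.chunk(2, dim=0)     # weight split along its first dimension

Y_full = X @ A
Y_reduced = X1 @ A1 + X2 @ A2  # what the all-reduce sums across ranks
print(torch.allclose(Y_full, Y_reduced, atol=1e-5))  # True
```
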
@@ -1,279 +0,0 @@
|
|||||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from cacheflow.parallel_utils.parallel_state import (
|
|
||||||
get_tensor_model_parallel_rank,
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
get_tensor_model_parallel_group,
|
|
||||||
)
|
|
||||||
from .utils import split_tensor_along_last_dim
|
|
||||||
|
|
||||||
|
|
||||||
def _reduce(input_):
|
|
||||||
"""All-reduce the input tensor across model parallel group."""
|
|
||||||
|
|
||||||
# Bypass the function if we are using only 1 GPU.
|
|
||||||
if get_tensor_model_parallel_world_size() == 1:
|
|
||||||
return input_
|
|
||||||
|
|
||||||
# All-reduce.
|
|
||||||
torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
|
|
||||||
|
|
||||||
return input_
|
|
||||||
|
|
||||||
|
|
||||||
def _split_along_last_dim(input_):
|
|
||||||
"""Split the tensor along its last dimension and keep the
|
|
||||||
corresponding slice."""
|
|
||||||
|
|
||||||
world_size = get_tensor_model_parallel_world_size()
|
|
||||||
# Bypass the function if we are using only 1 GPU.
|
|
||||||
if world_size == 1:
|
|
||||||
return input_
|
|
||||||
|
|
||||||
# Split along last dimension.
|
|
||||||
input_list = split_tensor_along_last_dim(input_, world_size)
|
|
||||||
|
|
||||||
# Note: torch.split does not create contiguous tensors by default.
|
|
||||||
rank = get_tensor_model_parallel_rank()
|
|
||||||
output = input_list[rank].contiguous()
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def _split_along_first_dim(input_):
|
|
||||||
"""Split the tensor along its first dimension and keep the
|
|
||||||
corresponding slice."""
|
|
||||||
|
|
||||||
world_size = get_tensor_model_parallel_world_size()
|
|
||||||
# Bypass the function if we are using only 1 GPU.
|
|
||||||
if world_size == 1:
|
|
||||||
return input_
|
|
||||||
|
|
||||||
# Split along first dimension.
|
|
||||||
dim_size = input_.size()[0]
|
|
||||||
assert dim_size % world_size == 0, \
|
|
||||||
"First dimension of the tensor should be divisible by tensor parallel size"
|
|
||||||
local_dim_size = dim_size // world_size
|
|
||||||
rank = get_tensor_model_parallel_rank()
|
|
||||||
dim_offset = rank * local_dim_size
|
|
||||||
|
|
||||||
output = input_[dim_offset:dim_offset+local_dim_size].contiguous()
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def _gather_along_last_dim(input_):
|
|
||||||
"""Gather tensors and concatinate along the last dimension."""
|
|
||||||
|
|
||||||
world_size = get_tensor_model_parallel_world_size()
|
|
||||||
# Bypass the function if we are using only 1 GPU.
|
|
||||||
if world_size == 1:
|
|
||||||
return input_
|
|
||||||
|
|
||||||
# Size and dimension.
|
|
||||||
last_dim = input_.dim() - 1
|
|
||||||
rank = get_tensor_model_parallel_rank()
|
|
||||||
|
|
||||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
|
||||||
tensor_list[rank] = input_
|
|
||||||
torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
|
|
||||||
|
|
||||||
# Note: torch.cat already creates a contiguous tensor.
|
|
||||||
output = torch.cat(tensor_list, dim=last_dim).contiguous()
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def _gather_along_first_dim(input_):
|
|
||||||
"""Gather tensors and concatinate along the first dimension."""
|
|
||||||
|
|
||||||
world_size = get_tensor_model_parallel_world_size()
|
|
||||||
# Bypass the function if we are using only 1 GPU.
|
|
||||||
if world_size == 1:
|
|
||||||
return input_
|
|
||||||
|
|
||||||
dim_size = list(input_.size())
|
|
||||||
dim_size[0] = dim_size[0] * world_size
|
|
||||||
|
|
||||||
output = torch.empty(dim_size, dtype=input_.dtype,
|
|
||||||
device=torch.cuda.current_device())
|
|
||||||
torch.distributed._all_gather_base(output, input_.contiguous(),
|
|
||||||
group=get_tensor_model_parallel_group())
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
def _reduce_scatter_along_first_dim(input_):
|
|
||||||
"""Reduce-scatter the input tensor across model parallel group."""
|
|
||||||
world_size = get_tensor_model_parallel_world_size()
|
|
||||||
# Bypass the function if we are using only 1 GPU.
|
|
||||||
if world_size == 1:
|
|
||||||
return input_
|
|
||||||
|
|
||||||
dim_size = list(input_.size())
|
|
||||||
assert dim_size[0] % world_size == 0, \
|
|
||||||
"First dimension of the tensor should be divisible by tensor parallel size"
|
|
||||||
|
|
||||||
dim_size[0] = dim_size[0] // world_size
|
|
||||||
|
|
||||||
output = torch.empty(dim_size, dtype=input_.dtype,
|
|
||||||
device=torch.cuda.current_device())
|
|
||||||
torch.distributed._reduce_scatter_base(output, input_.contiguous(),
|
|
||||||
group=get_tensor_model_parallel_group())
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class _CopyToModelParallelRegion(torch.autograd.Function):
|
|
||||||
"""Pass the input to the model parallel region."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def symbolic(graph, input_):
|
|
||||||
return input_
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, input_):
|
|
||||||
return input_
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
return _reduce(grad_output)
|
|
||||||
|
|
||||||
|
|
||||||
class _ReduceFromModelParallelRegion(torch.autograd.Function):
|
|
||||||
"""All-reduce the input from the model parallel region."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def symbolic(graph, input_):
|
|
||||||
return _reduce(input_)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, input_):
|
|
||||||
return _reduce(input_)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
return grad_output
|
|
||||||
|
|
||||||
|
|
||||||
class _ScatterToModelParallelRegion(torch.autograd.Function):
|
|
||||||
"""Split the input and keep only the corresponding chuck to the rank."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def symbolic(graph, input_):
|
|
||||||
return _split_along_last_dim(input_)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, input_):
|
|
||||||
return _split_along_last_dim(input_)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
return _gather_along_last_dim(grad_output)
|
|
||||||
|
|
||||||
|
|
||||||
class _GatherFromModelParallelRegion(torch.autograd.Function):
|
|
||||||
"""Gather the input from model parallel region and concatinate."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def symbolic(graph, input_):
|
|
||||||
return _gather_along_last_dim(input_)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, input_):
|
|
||||||
return _gather_along_last_dim(input_)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
return _split_along_last_dim(grad_output)
|
|
||||||
|
|
||||||
|
|
||||||
class _ScatterToSequenceParallelRegion(torch.autograd.Function):
|
|
||||||
"""Split the input and keep only the corresponding chuck to the rank."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def symbolic(graph, input_):
|
|
||||||
return _split_along_first_dim(input_)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, input_):
|
|
||||||
return _split_along_first_dim(input_)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
return _gather_along_first_dim(grad_output)
|
|
||||||
|
|
||||||
|
|
||||||
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
|
|
||||||
"""Gather the input from sequence parallel region and concatinate."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def symbolic(graph, input_, tensor_parallel_output_grad=True):
|
|
||||||
return _gather_along_first_dim(input_)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, input_, tensor_parallel_output_grad=True):
|
|
||||||
ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
|
|
||||||
return _gather_along_first_dim(input_)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
tensor_parallel_output_grad = ctx.tensor_parallel_output_grad
|
|
||||||
|
|
||||||
# If the computation graph after the gather operation is
|
|
||||||
# in the tensor parallel mode, output gradients need to reduce
|
|
||||||
# scattered and whereas if the computation is duplicated,
|
|
||||||
# output gradients need to be scattered.
|
|
||||||
if tensor_parallel_output_grad:
|
|
||||||
return _reduce_scatter_along_first_dim(grad_output), None
|
|
||||||
else:
|
|
||||||
return _split_along_first_dim(grad_output), None
|
|
||||||
|
|
||||||
|
|
||||||
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
|
|
||||||
"""Reduce scatter the input from the model parallel region."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def symbolic(graph, input_):
|
|
||||||
return _reduce_scatter_along_first_dim(input_)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, input_):
|
|
||||||
return _reduce_scatter_along_first_dim(input_)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
return _gather_along_first_dim(grad_output)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------
|
|
||||||
# Helper functions.
|
|
||||||
# -----------------
|
|
||||||
|
|
||||||
def copy_to_tensor_model_parallel_region(input_):
|
|
||||||
return _CopyToModelParallelRegion.apply(input_)
|
|
||||||
|
|
||||||
|
|
||||||
def reduce_from_tensor_model_parallel_region(input_):
|
|
||||||
return _ReduceFromModelParallelRegion.apply(input_)
|
|
||||||
|
|
||||||
|
|
||||||
def scatter_to_tensor_model_parallel_region(input_):
|
|
||||||
return _ScatterToModelParallelRegion.apply(input_)
|
|
||||||
|
|
||||||
|
|
||||||
def gather_from_tensor_model_parallel_region(input_):
|
|
||||||
return _GatherFromModelParallelRegion.apply(input_)
|
|
||||||
|
|
||||||
|
|
||||||
def scatter_to_sequence_parallel_region(input_):
|
|
||||||
return _ScatterToSequenceParallelRegion.apply(input_)
|
|
||||||
|
|
||||||
|
|
||||||
def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
|
|
||||||
return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)
|
|
||||||
|
|
||||||
|
|
||||||
def reduce_scatter_to_sequence_parallel_region(input_):
|
|
||||||
return _ReduceScatterToSequenceParallelRegion.apply(input_)
|
|
||||||
|
|
||||||
@@ -1,253 +0,0 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch

import contextlib

import torch
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable

from cacheflow.parallel_utils.parallel_state import (
    get_data_parallel_rank,
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)

from .utils import (
    split_tensor_into_1d_equal_chunks,
    gather_split_1d_tensor,
)

from cacheflow.parallel_utils.utils import safely_set_viewless_tensor_data

# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'


def _set_cuda_rng_state(new_state, device=-1):
    """Sets the random number generator state of the current GPU.

    Arguments:
        new_state (torch.ByteTensor): The desired state
    This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
    with a single change: the input state is not cloned. Cloning caused
    major performance issues for +4 GPU cases.
    """
    if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
        # older PyTorch
        def cb():
            with device_ctx_manager(device):
                _C._cuda_setRNGState(new_state)
    else:
        # newer PyTorch
        if device == -1:
            device = torch.device('cuda')
        elif isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device('cuda', device)

        def cb():
            idx = device.index
            if idx is None:
                idx = torch.cuda.current_device()
            default_generator = torch.cuda.default_generators[idx]
            default_generator.set_state(new_state)

    _lazy_call(cb)


class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.

    Using the `add` method, a cuda rng state is initialized based on
    the input `seed` and is assigned to `name`. Later, by forking the
    rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
        # Seeds are just for book keeping and ensure no seed is set twice.
        self.seeds_ = set()

    def reset(self):
        """Set to the initial state (no tracker)."""
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self):
        """Get rng states. Copy the dictionary so we have direct
        pointers to the states, not just a pointer to the dictionary."""
        states = {}
        for name in self.states_:
            states[name] = self.states_[name]
        return states

    def set_states(self, states):
        """Set the rng states. For efficiency purposes, we do not check
        the size of seed for compatibility."""
        self.states_ = states

    def add(self, name, seed):
        """Track the rng state."""
        # Check seed is not already used.
        if seed in self.seeds_:
            raise Exception('seed {} already exists'.format(seed))
        self.seeds_.add(seed)
        # Check that state is not already defined.
        if name in self.states_:
            raise Exception('cuda rng state {} already exists'.format(name))
        # Get the current rng state.
        orig_rng_state = torch.cuda.get_rng_state()
        # Set the new state and store it.
        torch.cuda.manual_seed(seed)
        self.states_[name] = torch.cuda.get_rng_state()
        # Reset rng state to what it was.
        _set_cuda_rng_state(orig_rng_state)

    @contextlib.contextmanager
    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
        """Fork the cuda rng state, perform operations, and exit with
        the original state."""
        # Check if we have added the state
        if name not in self.states_:
            raise Exception('cuda rng state {} is not added'.format(name))
        # Store current rng state.
        orig_cuda_rng_state = torch.cuda.get_rng_state()
        # Set rng state to the desired one
        _set_cuda_rng_state(self.states_[name])
        # Do the stuff we wanted to do.
        try:
            yield
        finally:
            # Update the current rng state for later use.
            self.states_[name] = torch.cuda.get_rng_state()
            # And set the state to the original state we started with.
            _set_cuda_rng_state(orig_cuda_rng_state)


# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()


def get_cuda_rng_tracker():
    """Get cuda rng tracker."""
    return _CUDA_RNG_STATE_TRACKER


def model_parallel_cuda_manual_seed(seed):
    """Initialize model parallel cuda seed.

    This function should be called after the model parallel is
    initialized. Also, no torch.cuda.manual_seed should be called
    after this function. Basically, this is replacement for that
    function.
    Two sets of RNG states are tracked:
    default state: This is for data parallelism and is the same among a
                   set of model parallel GPUs but different across
                   different model parallel groups. This is used for
                   example for dropout in the non-tensor-model-parallel regions.
    tensor-model-parallel state: This state is different among a set of model
                   parallel GPUs, but the same across data parallel
                   groups. This is used for example for dropout in
                   model parallel regions.
    """
    # 2718 is just for fun and any POSITIVE value will work.
    offset = seed + 2718
    tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
    # Data parallel gets the original seed.
    data_parallel_seed = seed

    _CUDA_RNG_STATE_TRACKER.reset()
    # Set the default state.
    torch.cuda.manual_seed(data_parallel_seed)
    # and model parallel state.
    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
                                tensor_model_parallel_seed)


class CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
    two main changes:
    1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
    2) the states in the model parallel tracker are also properly
       tracked/set/reset.
    """
    @staticmethod
    def forward(ctx, run_function, distribute_saved_activations, *args):
        ctx.run_function = run_function
        ctx.distribute_saved_activations \
            = distribute_saved_activations

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        with torch.no_grad():
            outputs = run_function(*args)

        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
        if distribute_saved_activations:
            ctx.input_0_shape = args[0].data.shape
            safely_set_viewless_tensor_data(
                args[0],
                split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))

        # Store everything.
        ctx.save_for_backward(*args)

        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")
        inputs = ctx.saved_tensors
        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what it used to be before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # Compute the forward pass.
        detached_inputs = detach_variable(inputs)
        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)

        # Set the states back to what it was at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        return (None, None) + grads


def checkpoint(function, distribute_saved_activations, *args):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint."""
    return CheckpointFunction.apply(function,
                                    distribute_saved_activations, *args)
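The tracker above is intended to be used by forking into the tensor-model-parallel RNG stream around random operations such as dropout, so those operations draw different numbers on each tensor-parallel rank while the default stream stays identical across the group. A hedged usage sketch (assumes CUDA and model parallelism are already initialized, so it is illustrative only):

# Illustrative only: requires an initialized CUDA / model-parallel setup.
model_parallel_cuda_manual_seed(1234)

with get_cuda_rng_tracker().fork():
    # Random ops here use the tensor-model-parallel RNG state, which differs
    # per tensor-parallel rank.
    y = torch.nn.functional.dropout(torch.ones(4, device='cuda'), p=0.5)
# Outside the context, the default (data-parallel) RNG state is restored.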
@@ -1,108 +0,0 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import torch
from typing import List, Sequence

from cacheflow.parallel_utils.utils import divide
from cacheflow.parallel_utils import parallel_state

def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
    """ Split a tensor along its last dimension.

        Arguments:
            tensor: input tensor.
            num_partitions: number of partitions to split the tensor
            contiguous_split_chunks: If True, make each chunk contiguous
                                     in memory.

        Returns:
            A list of Tensors
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list

def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
    """ Break a tensor into equal 1D chunks across tensor parallel ranks.

        Returns a Tensor or View with this rank's portion of the data.

        Arguments:
            tensor: The tensor to split

        Keyword Arguments:
            new_buffer (bool): If True, returns a new Tensor.
                               If False, returns a view into the existing Tensor.
                               Default is False

    """
    partition_size = torch.numel(tensor) // \
        parallel_state.get_tensor_model_parallel_world_size()
    start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
    end_index = start_index + partition_size
    if new_buffer:
        data = torch.empty(partition_size, dtype=tensor.dtype,
                           device=torch.cuda.current_device(),
                           requires_grad=False)
        data.copy_(tensor.view(-1)[start_index:end_index])
    else:
        data = tensor.view(-1)[start_index:end_index]
    return data


def gather_split_1d_tensor(tensor):
    """ Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
        model parallel ranks.

        Returns a new Tensor with the gathered data.

        Arguments:
            tensor: A Tensor or view of this rank's portion of the data.
    """
    numel_gathered = torch.numel(tensor) * \
        parallel_state.get_tensor_model_parallel_world_size()
    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                           device=torch.cuda.current_device(),
                           requires_grad=False)
    # TODO: This API is experimental in pytorch (as of Feb 2022) and
    # this might break in future pytorch releases. We chose this API
    # as opposed to torch.distributed.all_gather for efficiency reasons.
    # This API calls directly NCCL all-gather versus the former does
    # internal copies and can potentially cause slow down.
    torch.distributed._all_gather_base(gathered, tensor,
                                       group=parallel_state.get_tensor_model_parallel_group())
    return gathered


class VocabUtility:
    """ Split the vocabulary into `world_size` chunks and return the first
        and last index of the vocabulary belonging to the `rank`
        partition: Note that indices are in [first, last)

    """

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size: int, rank, world_size: int
    ) -> Sequence[int]:
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]:
        per_partition_vocab_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(
            per_partition_vocab_size, rank, world_size
        )
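To make the vocabulary-partitioning arithmetic concrete, a small worked example of VocabUtility (the numbers are illustrative assumptions, not values used by the project):

# E.g. a 50,000-token vocabulary split across 4 tensor-parallel ranks:
# each partition holds 12,500 entries, and rank 2 owns indices [25000, 37500).
first, last = VocabUtility.vocab_range_from_global_vocab_size(
    global_vocab_size=50000, rank=2, world_size=4)
assert (first, last) == (25000, 37500)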
@@ -1,120 +0,0 @@
"""Utility functions used throughout Megatron core"""
from functools import reduce
import operator

import torch

from cacheflow.parallel_utils import parallel_state


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator
    )


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator


class GlobalMemoryBuffer:
    """Global buffer to avoid dynamic memory allocations.
    Caller should ensure that buffers of the same name
    are not used concurrently."""

    def __init__(self):
        self.buffer = {}

    def get_tensor(self, tensor_shape, dtype, name):
        required_len = reduce(operator.mul, tensor_shape, 1)
        if self.buffer.get((name, dtype), None) is None or \
                self.buffer[(name, dtype)].numel() < required_len:
            self.buffer[(name, dtype)] = \
                torch.empty(required_len,
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                            requires_grad=False)

        return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)


def _kernel_make_viewless_tensor(inp, requires_grad):
    '''Make a viewless tensor.

    View tensors have the undesirable side-effect of retaining a reference
    to the originally-viewed tensor, even after manually setting the '.data'
    field. This method creates a new tensor that links to the old tensor's
    data, without linking the viewed tensor, referenced via the '._base'
    field.
    '''
    out = torch.empty(
        (1,),
        dtype = inp.dtype,
        device = inp.device,
        requires_grad = requires_grad,
    )
    out.data = inp.data
    return out

class MakeViewlessTensor(torch.autograd.Function):
    '''
    Autograd function to make a viewless tensor.

    This function should be used in cases where the computation graph needs
    to be propagated, but we only want a viewless tensor (e.g.,
    ParallelTransformer's hidden_states). Call this function by passing
    'keep_graph = True' to 'make_viewless_tensor()'.
    '''
    @staticmethod
    def forward(ctx, inp, requires_grad):
        return _kernel_make_viewless_tensor(inp, requires_grad)
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

def make_viewless_tensor(inp, requires_grad, keep_graph):
    '''
    Entry-point for creating viewless tensors.

    This method should be used, rather than calling 'MakeViewlessTensor'
    or '_kernel_make_viewless_tensor' directly. This method acts as a
    switch for determining if an autograd function or a regular method
    should be used to create the tensor.
    '''

    # return tensor as-is, if not a 'view'
    if inp._base is None:
        return inp

    # create viewless tensor
    if keep_graph:
        return MakeViewlessTensor.apply(inp, requires_grad)
    else:
        return _kernel_make_viewless_tensor(inp, requires_grad)

def assert_viewless_tensor(tensor, extra_msg = None):
    '''Assert that a tensor is not a view (i.e., its '._base' field is
    not set).'''
    if isinstance(tensor, list):
        [ assert_viewless_tensor(t) for t in tensor ]
        return tensor
    if not isinstance(tensor, torch.Tensor):
        return tensor
    assert tensor._base is None, (
        "Ensure tensor._base is None before setting tensor.data or storing "
        "tensor to memory buffer. Otherwise, a memory leak will occur (and "
        "likely accumulate over iterations). %s"
    ) % extra_msg
    return tensor

def safely_set_viewless_tensor_data(tensor, new_data_tensor):
    '''Safely set tensor's '.data' field.

    Check first that the tensor is viewless (i.e., '._base' not set). If not,
    raise an exception.
    '''
    assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
    tensor.data = new_data_tensor
@@ -1,84 +0,0 @@
from typing import Optional, Set, Dict


class SamplingParams:

    def __init__(
        self,
        n: int,
        temperature: float,
        top_p: float,
        use_beam_search: bool,
        stop_token_ids: Set[int],
        max_num_steps: int,
        num_logprobs: int,
        context_window_size: Optional[int],
    ) -> None:
        if n < 1:
            raise ValueError(f'n must be at least 1, got {n}.')
        if temperature < 0.0:
            raise ValueError(
                f'temperature must be non-negative, got {temperature}.')
        if not 0.0 < top_p <= 1.0:
            raise ValueError(f'top_p must be in (0, 1], got {top_p}.')
        if max_num_steps < 1:
            raise ValueError(
                f'max_num_steps must be at least 1, got {max_num_steps}.')
        if num_logprobs < 0:
            raise ValueError(
                f'num_logprobs must be non-negative, got {num_logprobs}.')
        if context_window_size is not None and context_window_size < 0:
            raise ValueError(
                'context_window_size must be non-negative, '
                f'got {context_window_size}.')

        if use_beam_search:
            if n == 1:
                raise ValueError(
                    'n must be greater than 1 when using beam search.')
            if temperature > 0.0:
                raise ValueError(
                    'temperature must be 0 when using beam search.')
            if top_p < 1.0:
                raise ValueError(
                    'top_p must be 1 when using beam search.')
        elif temperature == 0.0:
            # Zero temperature means greedy sampling.
            if n > 1:
                raise ValueError(
                    'n must be 1 when using greedy sampling.')
            if top_p < 1.0:
                raise ValueError(
                    'top_p must be 1 when using greedy sampling.')

        self.n = n
        self.temperature = temperature
        self.top_p = top_p
        self.use_beam_search = use_beam_search
        self.stop_token_ids = stop_token_ids
        self.max_num_steps = max_num_steps
        self.num_logprobs = num_logprobs
        self.context_window_size = context_window_size

    def __repr__(self) -> str:
        return (f'SamplingParams(n={self.n}, '
                f'temperature={self.temperature}, '
                f'top_p={self.top_p}, '
                f'use_beam_search={self.use_beam_search}, '
                f'stop_token_ids={self.stop_token_ids}, '
                f'max_num_steps={self.max_num_steps}, '
                f'num_logprobs={self.num_logprobs}, '
                f'context_window_size={self.context_window_size})')

    @classmethod
    def from_dict(cls, d: Dict) -> 'SamplingParams':
        return cls(
            n=d.get('n', 1),
            temperature=d.get('temperature', 1.0),
            top_p=d.get('top_p', 1.0),
            use_beam_search=d.get('use_beam_search', False),
            stop_token_ids=set(d.get('stop_token_ids', set())),
            max_num_steps=d.get('max_num_steps', 16),
            num_logprobs=d.get('num_logprobs', 0),
            context_window_size=d.get('context_window_size', None),
        )
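The constructor enforces that the flags are mutually consistent (beam search needs n > 1, temperature 0, top_p 1; greedy sampling needs n = 1 and top_p 1), and from_dict fills in defaults for anything omitted. A brief, hedged usage sketch of this early API:

# Defaults fill in everything not specified; temperature 0 with the default
# n=1 and top_p=1.0 passes the greedy-sampling checks.
params = SamplingParams.from_dict({'temperature': 0.0, 'max_num_steps': 32})

# Inconsistent combinations are rejected, e.g. greedy sampling with n > 1:
# SamplingParams.from_dict({'temperature': 0.0, 'n': 4})  # raises ValueError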
@@ -1,169 +0,0 @@
import copy
import enum
from typing import Dict, List, Optional

from cacheflow.block import LogicalTokenBlock
from cacheflow.sampling_params import SamplingParams


class SequenceStatus(enum.Enum):
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED = enum.auto()


class Sequence:

    def __init__(
        self,
        seq_id: int,
        token_ids: List[int],
        block_size: int,
    ) -> None:
        self.seq_id = seq_id
        self.block_size = block_size

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the given token ids.
        self.add(token_ids)

        self.prompt_len = len(token_ids)
        self.status = SequenceStatus.WAITING
        self.output_logprobs: List[Dict[int, float]] = []
        self.cumulative_logprobs = 0.0

    def add_block(self) -> None:
        block = LogicalTokenBlock(
            block_number=len(self.logical_token_blocks),
            block_size=self.block_size,
        )
        self.logical_token_blocks.append(block)

    def add(self, token_ids: List[int]) -> None:
        while token_ids:
            if not self.logical_token_blocks:
                self.add_block()

            last_block = self.logical_token_blocks[-1]
            if last_block.is_full():
                self.add_block()
                last_block = self.logical_token_blocks[-1]

            num_empty_slots = last_block.get_num_empty_slots()
            last_block.append(token_ids[:num_empty_slots])
            token_ids = token_ids[num_empty_slots:]

    def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
        assert token_id in logprobs
        self.add([token_id])
        self.output_logprobs.append(logprobs)
        self.cumulative_logprobs += logprobs[token_id]

    def get_len(self) -> int:
        return sum(block.num_tokens for block in self.logical_token_blocks)

    def get_token_ids(self) -> List[int]:
        token_ids: List[int] = []
        for block in self.logical_token_blocks:
            token_ids.extend(block.get_token_ids())
        return token_ids

    def get_last_token_id(self) -> int:
        return self.logical_token_blocks[-1].get_last_token_id()

    def fork(self, child_seq: 'Sequence') -> 'Sequence':
        child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
        child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
        child_seq.cumulative_logprobs = self.cumulative_logprobs

    def __repr__(self) -> str:
        return (f'Sequence(seq_id={self.seq_id}, '
                f'status={self.status.name}, '
                f'num_blocks={len(self.logical_token_blocks)})')


class SequenceGroup:

    def __init__(
        self,
        group_id: int,
        seqs: List[Sequence],
        arrival_time: float,
    ) -> None:
        self.group_id = group_id
        self.seqs = seqs
        self.arrival_time = arrival_time

    def get_seqs(
        self,
        status: Optional[SequenceStatus] = None,
    ) -> List[Sequence]:
        if status is None:
            return self.seqs
        else:
            return [seq for seq in self.seqs if seq.status == status]

    def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
        return len(self.get_seqs(status))

    def find(self, seq_id: int) -> Sequence:
        for seq in self.seqs:
            if seq.seq_id == seq_id:
                return seq
        raise ValueError(f'Sequence {seq_id} not found.')

    def is_finished(self) -> bool:
        return all(seq.status == SequenceStatus.FINISHED for seq in self.seqs)

    def __repr__(self) -> str:
        return (f'SequenceGroup(group_id={self.group_id}, '
                f'num_seqs={len(self.seqs)})')


class SequenceGroupInputs:

    def __init__(
        self,
        group_id: int,
        is_prompt: bool,
        input_tokens: Dict[int, List[int]],  # Seq id -> token ids.
        context_len: int,
        seq_logprobs: Dict[int, float],  # Seq id -> cumulative logprobs.
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],  # Seq id -> List of physical block numbers.
    ) -> None:
        self.group_id = group_id
        self.is_prompt = is_prompt
        self.input_tokens = input_tokens
        self.context_len = context_len
        self.seq_logprobs = seq_logprobs
        self.sampling_params = sampling_params
        self.block_tables = block_tables


class SequenceOutputs:

    def __init__(
        self,
        seq_id: int,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, float],  # Token id -> logP(x_i+1 | x_0, ..., x_i).
    ) -> None:
        self.seq_id = seq_id
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f'SequenceOutputs(seq_id={self.seq_id}, '
                f'parent_seq_id={self.parent_seq_id}, '
                f'output_token={self.output_token}), '
                f'logprobs={self.logprobs}')

    def __eq__(self, other: 'SequenceOutputs') -> bool:
        return (self.seq_id == other.seq_id and
                self.parent_seq_id == other.parent_seq_id and
                self.output_token == other.output_token and
                self.logprobs == other.logprobs)
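Sequence.add stores tokens in fixed-size logical blocks, filling the last block before opening a new one. A self-contained sketch of that bookkeeping with plain Python lists (LogicalTokenBlock itself is defined elsewhere; the block size here is an illustrative assumption):

# Block-size-4 bookkeeping for 10 prompt tokens: blocks fill up left to right.
block_size = 4
token_ids = list(range(10))
blocks = [token_ids[i:i + block_size] for i in range(0, len(token_ids), block_size)]
assert blocks == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
# Appending two more tokens fills the last block; a further append would
# allocate a fourth block, exactly as Sequence.add / add_block do.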
@@ -1,47 +0,0 @@
import enum
import random
import psutil

import numpy as np
import torch

from cacheflow.parallel_utils.parallel_state import model_parallel_is_initialized
from cacheflow.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed


class Device(enum.Enum):
    GPU = enum.auto()
    CPU = enum.auto()


class Counter:

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        id = self.counter
        self.counter += 1
        return id

    def reset(self) -> None:
        self.counter = 0


def set_random_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    if model_parallel_is_initialized():
        model_parallel_cuda_manual_seed(seed)


def get_gpu_memory(gpu: int = 0) -> int:
    return torch.cuda.get_device_properties(gpu).total_memory


def get_cpu_memory() -> int:
    return psutil.virtual_memory().total
@@ -1,101 +0,0 @@
from typing import Dict, List, Union, Tuple

import ray

from cacheflow.master.scheduler import Scheduler
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.worker.worker import Worker


DeviceID = Tuple[int, str, int]  # rank, node resource (node IP), device id


class Controller:

    def __init__(
        self,
        stage_id: int,
        stage_devices: List[DeviceID],
        world_size: int,
        tensor_parallel_size: int,
        pipeline_parallel_size: int,
        distributed_init_method: str,
        model_name: str,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: str,
        seed: int,
        model_path: str,
        use_dummy_weights: bool,
        max_num_batched_tokens: int,
    ) -> None:
        self.stage_id = stage_id
        self.stage_devices = stage_devices
        self.model_name = model_name
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks

        # Which pipeline stage is this node assigned to?
        self.is_first_stage = stage_id == 0
        self.is_last_stage = False

        self.workers: List[Worker] = []
        for rank, node_resource, device_id in stage_devices:
            worker_cls = ray.remote(num_cpus=0,
                                    num_gpus=1,
                                    resources={node_resource: 1e-5})(Worker)
            worker = worker_cls.remote(
                model_name=model_name,
                block_size=block_size,
                num_gpu_blocks=num_gpu_blocks,
                num_cpu_blocks=num_cpu_blocks,
                dtype=dtype,
                seed=seed,
                distributed_init_method=distributed_init_method,
                rank=rank,
                world_size=world_size,
                tensor_parallel_size=tensor_parallel_size,
                pipeline_parallel_size=pipeline_parallel_size,
                model_path=model_path,
                use_dummy_weights=use_dummy_weights,
                max_num_batched_tokens=max_num_batched_tokens,
            )
            self.workers.append(worker)

    def set_next(
        self,
        next_node: Union['Controller', 'Scheduler'],
    ) -> None:
        self.next_node = next_node
        self.is_last_stage = isinstance(next_node, Scheduler)

    def execute_stage(
        self,
        input_seq_groups: List[SequenceGroupInputs],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        futures = []
        for worker in self.workers:
            future = worker.execute_stage.remote(
                input_seq_groups,
                blocks_to_swap_in,
                blocks_to_swap_out,
                blocks_to_copy,
            )
            futures.append(future)

        all_outputs = ray.get(futures)
        # Make sure all workers have the same results.
        output = all_outputs[0]
        for other_output in all_outputs[1:]:
            assert output == other_output

        if self.is_last_stage:
            self.next_node.post_step(output)
        else:
            # TODO: Support pipeline parallelism.
            assert False
@@ -1,264 +0,0 @@
from typing import Dict, List, Tuple

import torch

from cacheflow.models import get_model
from cacheflow.models import InputMetadata
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
from cacheflow.worker.cache_engine import CacheEngine
from cacheflow.parallel_utils.parallel_state import (
    initialize_model_parallel,
    initialize_all_reduce_launcher,
    get_tensor_model_parallel_world_size)
from cacheflow.utils import set_random_seed


class Worker:

    def __init__(
        self,
        model_name: str,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        dtype: str,
        seed: int,
        distributed_init_method: str,
        rank: int,
        world_size: int,
        model_path: str,
        use_dummy_weights: bool,
        max_num_batched_tokens: int,
        tensor_parallel_size: int = 1,
        pipeline_parallel_size: int = 1,
    ) -> None:
        self.init_distributed_environment(distributed_init_method,
                                          rank,
                                          world_size,
                                          tensor_parallel_size,
                                          pipeline_parallel_size)
        self.worker_id = rank
        self.block_size = block_size
        set_random_seed(seed)

        # Initialize the model.
        self.model, self.dtype = get_model(
            model_name, dtype=dtype, path=model_path, use_dummy_weights=use_dummy_weights)
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        initialize_all_reduce_launcher(
            max_num_batched_tokens, self.model.config.hidden_size, self.dtype)
        self.num_layers = self.model.config.num_hidden_layers
        assert self.model.config.num_attention_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.model.config.num_attention_heads // tensor_model_parallel_world_size
        self.head_size = self.model.config.hidden_size // (self.num_heads * tensor_model_parallel_world_size)

        # We reset the seed after initializing the model to ensure that
        # the random state is not affected by the model initialization.
        set_random_seed(seed)

        self.cache_engine = CacheEngine(
            worker_id=self.worker_id,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            head_size=self.head_size,
            block_size=block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            dtype=self.dtype,
        )
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache

    def init_distributed_environment(self,
                                     distributed_init_method: str,
                                     rank: int,
                                     world_size: int,
                                     tensor_parallel_size: int = 1,
                                     pipeline_parallel_size: int = 1) -> None:
        """Initialize the distributed environment."""
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
        )
        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cuda())
        initialize_model_parallel(tensor_parallel_size,
                                  pipeline_parallel_size)

    def prepare_inputs(
        self,
        input_seq_groups: List[SequenceGroupInputs],
    ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        seq_logprobs: Dict[int, float] = {}
        sampling_params: Dict[int, SamplingParams] = {}
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []

        # Add prompt tokens.
        prompt_lens: List[int] = []
        for input_seq_group in input_seq_groups:
            if not input_seq_group.is_prompt:
                continue

            seq_ids = list(input_seq_group.input_tokens.keys())
            sampling_params = input_seq_group.sampling_params
            seq_groups.append((seq_ids, sampling_params))
            seq_logprobs.update(input_seq_group.seq_logprobs)

            # Use any sequence in the group.
            seq_id = seq_ids[0]

            prompt_tokens = input_seq_group.input_tokens[seq_id]
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(len(prompt_tokens)))

            # Compute the slot mapping.
            block_table = input_seq_group.block_tables[seq_id]
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        cumulative_prompt_lens: List[int] = [0]
        for prompt_len in prompt_lens:
            cumulative_prompt_lens.append(
                cumulative_prompt_lens[-1] + prompt_len)

        # Add generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
        context_lens: List[int] = []
        generation_block_tables: List[List[int]] = []
        for input_seq_group in input_seq_groups:
            if input_seq_group.is_prompt:
                continue

            seq_ids = list(input_seq_group.input_tokens.keys())
            sampling_params = input_seq_group.sampling_params
            seq_groups.append((seq_ids, sampling_params))
            seq_logprobs.update(input_seq_group.seq_logprobs)

            for seq_id in seq_ids:
                assert len(input_seq_group.input_tokens[seq_id]) == 1
                generation_token = input_seq_group.input_tokens[seq_id][0]
                input_tokens.append(generation_token)

                position = input_seq_group.context_len - 1
                input_positions.append(position)

                block_table = input_seq_group.block_tables[seq_id]
                generation_block_tables.append(block_table)

                max_context_len = max(
                    max_context_len, input_seq_group.context_len)
                max_num_blocks_per_seq = max(
                    max_num_blocks_per_seq, len(block_table))
                context_lens.append(input_seq_group.context_len)

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Optimization: Pad the input length to be a multiple of 8.
        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
        input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
        input_positions = _pad_to_alignment(input_positions, multiple_of=8)

        # Convert to tensors.
        tokens_tensor = torch.tensor(
            input_tokens, dtype=torch.long, device='cuda')
        positions_tensor = torch.tensor(
            input_positions, dtype=torch.long, device='cuda')
        slot_mapping_tensor = torch.tensor(
            slot_mapping, dtype=torch.int, device='cuda')
        context_lens_tensor = torch.tensor(
            context_lens, dtype=torch.int, device='cuda')
        padded_block_tables = [
            _pad_to_max(block_table, max_num_blocks_per_seq)
            for block_table in generation_block_tables]
        block_tables_tensor = torch.tensor(
            padded_block_tables, dtype=torch.int, device='cuda')
        cumulative_prompt_lens_tensor = torch.tensor(
            cumulative_prompt_lens, dtype=torch.int, device='cuda')

        input_metadata = InputMetadata(
            seq_groups=seq_groups,
            seq_logprobs=seq_logprobs,
            prompt_lens=prompt_lens,
            cumulative_prompt_lens=cumulative_prompt_lens_tensor,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
            max_context_len=max_context_len,
            block_tables=block_tables_tensor,
        )
        return tokens_tensor, positions_tensor, input_metadata

    @torch.inference_mode()
    def execute_stage(
        self,
        input_seq_groups: List[SequenceGroupInputs],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> Dict[int, SequenceOutputs]:
        # Issue cache operations.
        command_issued = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            command_issued = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            command_issued = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            command_issued = True

        if command_issued:
            cache_events = self.cache_events
        else:
            cache_events = None

        # If there is no input, we don't need to execute the model.
        if not input_seq_groups:
            if cache_events is not None:
                for event in cache_events:
                    event.wait()
            return {}

        # Prepare input tensors.
        input_tokens, input_positions, input_metadata = self.prepare_inputs(
            input_seq_groups)

        # Execute the model.
        output = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=self.gpu_cache,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
        return output


def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
    return x + [0] * ((-len(x)) % multiple_of)


def _pad_to_max(x: List[int], max_len: int) -> List[int]:
    return x + [0] * (max_len - len(x))
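prepare_inputs maps each token position to a physical KV-cache slot through the sequence's block table: slot = physical_block_number * block_size + offset within the block, and input lengths are padded to a multiple of 8. A small worked sketch of that arithmetic (the block table and block size below are made-up values for illustration):

# Position 21 with block_size 16 lives in the sequence's second logical block;
# if that logical block maps to physical block 7, the slot index is 7*16 + 5.
block_size = 16
block_table = [3, 7]   # logical block index -> physical block number
position = 21
slot = block_table[position // block_size] * block_size + position % block_size
assert slot == 117

# The padding helper rounds a token list up to a multiple of 8 with zeros.
assert _pad_to_alignment([1, 2, 3], multiple_of=8) == [1, 2, 3, 0, 0, 0, 0, 0]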
@@ -4,9 +4,25 @@ void silu_and_mul(
   torch::Tensor& out,
   torch::Tensor& input);

+void gelu_new(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+void gelu_fast(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
     "silu_and_mul",
     &silu_and_mul,
     "Activation function used in SwiGLU.");
+  m.def(
+    "gelu_new",
+    &gelu_new,
+    "GELU implementation used in GPT-2.");
+  m.def(
+    "gelu_fast",
+    &gelu_fast,
+    "Approximate GELU implementation.");
 }
@ -1,7 +1,9 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
namespace cacheflow {
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__device__ __forceinline__ T silu(const T& x) {
|
__device__ __forceinline__ T silu(const T& x) {
|
||||||
@ -22,7 +24,7 @@ __global__ void silu_and_mul_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cacheflow
|
} // namespace vllm
|
||||||
|
|
||||||
void silu_and_mul(
|
void silu_and_mul(
|
||||||
torch::Tensor& out, // [num_tokens, d]
|
torch::Tensor& out, // [num_tokens, d]
|
||||||
@ -34,13 +36,79 @@ void silu_and_mul(
|
|||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(d, 1024));
|
dim3 block(std::min(d, 1024));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
input.scalar_type(),
|
input.scalar_type(),
|
||||||
"silu_and_mul_kernel",
|
"silu_and_mul_kernel",
|
||||||
[&] {
|
[&] {
|
||||||
cacheflow::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
out.data_ptr<scalar_t>(),
|
out.data_ptr<scalar_t>(),
|
||||||
input.data_ptr<scalar_t>(),
|
input.data_ptr<scalar_t>(),
|
||||||
d);
|
d);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
// Element-wise activation kernel template.
|
||||||
|
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
||||||
|
__global__ void activation_kernel(
|
||||||
|
scalar_t* __restrict__ out, // [num_tokens, d]
|
||||||
|
const scalar_t* __restrict__ input, // [num_tokens, d]
|
||||||
|
const int d) {
|
||||||
|
const int token_idx = blockIdx.x;
|
||||||
|
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||||
|
const scalar_t x = __ldg(&input[token_idx * d + idx]);
|
||||||
|
out[token_idx * d + idx] = ACT_FN(x);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
// Launch element-wise activation kernel.
|
||||||
|
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
|
||||||
|
int num_tokens = input.size(0); \
|
||||||
|
int d = input.size(1); \
|
||||||
|
dim3 grid(num_tokens); \
|
||||||
|
dim3 block(std::min(d, 1024)); \
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||||
|
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||||
|
input.scalar_type(), \
|
||||||
|
"activation_kernel", \
|
||||||
|
[&] { \
|
||||||
|
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
|
||||||
|
out.data_ptr<scalar_t>(), \
|
||||||
|
input.data_ptr<scalar_t>(), \
|
||||||
|
d); \
|
||||||
|
});
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
|
||||||
|
const float x3 = (float) (x * x * x);
|
||||||
|
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
|
||||||
|
return ((T) 0.5) * x * (((T) 1.0) + t);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
|
||||||
|
const float f = (float) x;
|
||||||
|
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
|
||||||
|
return ((T) 0.5) * x * (((T) 1.0) + t);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
void gelu_new(
|
||||||
|
torch::Tensor& out, // [num_tokens, d]
|
||||||
|
torch::Tensor& input) // [num_tokens, d]
|
||||||
|
{
|
||||||
|
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
|
||||||
|
}
|
||||||
|
|
||||||
|
void gelu_fast(
|
||||||
|
torch::Tensor& out, // [num_tokens, d]
|
||||||
|
torch::Tensor& input) // [num_tokens, d]
|
||||||
|
{
|
||||||
|
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
|
||||||
|
}
|
||||||
|
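As a cross-check of the math in gelu_new_kernel and gelu_fast_kernel above, here is a small PyTorch reference (a sketch, not part of the diff) that mirrors the two tanh approximations element-wise:

    import torch

    def gelu_new_ref(x: torch.Tensor) -> torch.Tensor:
        # 0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x**3)))
        return 0.5 * x * (1.0 + torch.tanh(0.79788456 * (x + 0.044715 * x.pow(3))))

    def gelu_fast_ref(x: torch.Tensor) -> torch.Tensor:
        # 0.5 * x * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
        return 0.5 * x * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))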
@ -1,19 +1,42 @@
#include <torch/extension.h>
#include <c10/util/Optional.h>

void paged_attention_v1(
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& head_mapping,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes);

void paged_attention_v2(
  torch::Tensor& out,
  torch::Tensor& exp_sums,
  torch::Tensor& max_logits,
  torch::Tensor& tmp_out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& head_mapping,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "paged_attention_v1",
    &paged_attention_v1,
    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
  m.def(
    "paged_attention_v2",
    &paged_attention_v2,
    "PagedAttention V2.");
}
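A shape-level sketch of calling the V1 binding from Python. The module path (vllm.attention_ops), the value x = 8 (16 bytes / sizeof(fp16)), the dummy block table, and passing None for the optional ALiBi slopes are assumptions for illustration; the argument order follows the declaration above and the tensor layouts follow the comments in attention_kernels.cu below.

    import torch
    from vllm import attention_ops  # assumed extension module name

    num_seqs, num_heads, head_size = 2, 8, 128
    num_blocks, block_size, x = 64, 16, 8   # x = 16 // sizeof(fp16), assumed

    query = torch.randn(num_seqs, num_heads, head_size, dtype=torch.float16, device="cuda")
    key_cache = torch.randn(num_blocks, num_heads, head_size // x, block_size, x,
                            dtype=torch.float16, device="cuda")
    value_cache = torch.randn(num_blocks, num_heads, head_size, block_size,
                              dtype=torch.float16, device="cuda")
    head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda")
    block_tables = torch.zeros(num_seqs, 2, dtype=torch.int32, device="cuda")  # dummy table
    context_lens = torch.full((num_seqs,), 32, dtype=torch.int32, device="cuda")
    out = torch.empty_like(query)

    attention_ops.paged_attention_v1(
        out, query, key_cache, value_cache, head_mapping,
        head_size ** -0.5, block_tables, context_lens,
        block_size, int(context_lens.max()), None)  # alibi_slopes=None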
6  csrc/attention/attention_dtypes.h  Normal file
@ -0,0 +1,6 @@
#pragma once

#include "attention_generic.cuh"
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"
64  csrc/attention/attention_generic.cuh  Normal file
@ -0,0 +1,64 @@
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <stdint.h>

namespace vllm {

// A vector type to store Q, K, V elements.
template<typename T, int VEC_SIZE>
struct Vec {};

// A vector type to store FP32 accumulators.
template<typename T>
struct FloatVec {};

// Template vector operations.
template<typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);

template<typename T>
inline __device__ float sum(T v);

template<typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<T, T, T>(a, b));
}

template<typename A, typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<A, T, T>(a, b));
}

template<typename T>
inline __device__ void zero(T& dst) {
  constexpr int WORDS = sizeof(T) / 4;
  union {
    T raw;
    uint32_t words[WORDS];
  } tmp;

#pragma unroll
  for (int ii = 0; ii < WORDS; ++ii) {
    tmp.words[ii] = 0u;
  }
  dst = tmp.raw;
}

} // namespace vllm
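The Vec template above is specialized (in the dtype_*.cuh headers) so that each thread group fetches 16 bytes per step, as described in the kernel comments in attention_kernels.cu below. A tiny Python sanity check of that sizing rule, mirroring the "16 / (4 * sizeof(half)) == 2" example in those comments; this snippet is illustrative only:

    def vec_size(thread_group_size: int, dtype_bytes: int) -> int:
        # Elements each thread handles so the group fetches 16 bytes at a time.
        return max(16 // (thread_group_size * dtype_bytes), 1)

    assert vec_size(4, 2) == 2   # 4 threads, fp16 -> 2 elements per thread
    assert vec_size(2, 4) == 2   # 2 threads, fp32
    assert vec_size(16, 2) == 1  # wide groups fall back to 1 element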
866  csrc/attention/attention_kernels.cu  Normal file
@ -0,0 +1,866 @@
|
|||||||
|
/*
|
||||||
|
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||||
|
* Copyright (c) 2023, The vLLM team.
|
||||||
|
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include "attention_dtypes.h"
|
||||||
|
#include "attention_utils.cuh"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#define WARP_SIZE 32
|
||||||
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||||
|
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||||
|
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
// Utility function for attention softmax.
|
||||||
|
template<int NUM_WARPS>
|
||||||
|
inline __device__ float block_sum(float* red_smem, float sum) {
|
||||||
|
// Decompose the thread index into warp / lane.
|
||||||
|
int warp = threadIdx.x / WARP_SIZE;
|
||||||
|
int lane = threadIdx.x % WARP_SIZE;
|
||||||
|
|
||||||
|
// Compute the sum per warp.
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||||
|
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warp leaders store the data to shared memory.
|
||||||
|
if (lane == 0) {
|
||||||
|
red_smem[warp] = sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure the data is in shared memory.
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// The warps compute the final sums.
|
||||||
|
if (lane < NUM_WARPS) {
|
||||||
|
sum = red_smem[lane];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parallel reduction inside the warp.
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||||
|
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Broadcast to other threads.
|
||||||
|
return __shfl_sync(uint32_t(-1), sum, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(woosuk): Merge the last two dimensions of the grid.
|
||||||
|
// Grid: (num_heads, num_seqs, max_num_partitions).
|
||||||
|
template<
|
||||||
|
typename scalar_t,
|
||||||
|
int HEAD_SIZE,
|
||||||
|
int BLOCK_SIZE,
|
||||||
|
int NUM_THREADS,
|
||||||
|
int PARTITION_SIZE = 0> // Zero means no partitioning.
|
||||||
|
__device__ void paged_attention_kernel(
|
||||||
|
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||||
|
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||||
|
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||||
|
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||||
|
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||||
|
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||||
|
const int* __restrict__ head_mapping, // [num_heads]
|
||||||
|
const float scale,
|
||||||
|
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
|
const int* __restrict__ context_lens, // [num_seqs]
|
||||||
|
const int max_num_blocks_per_seq,
|
||||||
|
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||||
|
const int q_stride,
|
||||||
|
const int kv_block_stride,
|
||||||
|
const int kv_head_stride) {
|
||||||
|
const int seq_idx = blockIdx.y;
|
||||||
|
const int partition_idx = blockIdx.z;
|
||||||
|
const int max_num_partitions = gridDim.z;
|
||||||
|
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
|
||||||
|
const int context_len = context_lens[seq_idx];
|
||||||
|
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
|
||||||
|
// No work to do. Terminate the thread block.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
|
||||||
|
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
|
||||||
|
|
||||||
|
// [start_block_idx, end_block_idx) is the range of blocks to process.
|
||||||
|
const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
|
||||||
|
const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
|
||||||
|
const int num_blocks = end_block_idx - start_block_idx;
|
||||||
|
|
||||||
|
// [start_token_idx, end_token_idx) is the range of tokens to process.
|
||||||
|
const int start_token_idx = start_block_idx * BLOCK_SIZE;
|
||||||
|
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
|
||||||
|
const int num_tokens = end_token_idx - start_token_idx;
|
||||||
|
|
||||||
|
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||||
|
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
|
||||||
|
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
|
||||||
|
constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
|
||||||
|
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||||
|
const int thread_idx = threadIdx.x;
|
||||||
|
const int warp_idx = thread_idx / WARP_SIZE;
|
||||||
|
const int lane = thread_idx % WARP_SIZE;
|
||||||
|
|
||||||
|
const int head_idx = blockIdx.x;
|
||||||
|
const int num_heads = gridDim.x;
|
||||||
|
const int kv_head_idx = head_mapping[head_idx];
|
||||||
|
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
|
||||||
|
|
||||||
|
// A vector type to store a part of a key or a query.
|
||||||
|
// The vector size is configured in such a way that the threads in a thread group
|
||||||
|
// fetch or compute 16 bytes at a time.
|
||||||
|
// For example, if the size of a thread group is 4 and the data type is half,
|
||||||
|
// then the vector size is 16 / (4 * sizeof(half)) == 2.
|
||||||
|
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
|
||||||
|
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||||
|
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||||
|
|
||||||
|
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
|
||||||
|
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
|
||||||
|
|
||||||
|
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
|
||||||
|
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
|
||||||
|
|
||||||
|
// Load the query to registers.
|
||||||
|
// Each thread in a thread group has a different part of the query.
|
||||||
|
// For example, if the the thread group size is 4, then the first thread in the group
|
||||||
|
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
|
||||||
|
// th vectors of the query, and so on.
|
||||||
|
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
|
||||||
|
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||||
|
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
|
||||||
|
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
|
||||||
|
q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
|
||||||
|
}
|
||||||
|
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
|
||||||
|
|
||||||
|
// Memory planning.
|
||||||
|
extern __shared__ char shared_mem[];
|
||||||
|
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
|
||||||
|
float* logits = reinterpret_cast<float*>(shared_mem);
|
||||||
|
// Workspace for reduction.
|
||||||
|
__shared__ float red_smem[2 * NUM_WARPS];
|
||||||
|
|
||||||
|
// x == THREAD_GROUP_SIZE * VEC_SIZE
|
||||||
|
// Each thread group fetches x elements from the key at a time.
|
||||||
|
constexpr int x = 16 / sizeof(scalar_t);
|
||||||
|
float qk_max = -FLT_MAX;
|
||||||
|
|
||||||
|
// Iterate over the key blocks.
|
||||||
|
// Each warp fetches a block of keys for each iteration.
|
||||||
|
// Each thread group in a warp fetches a key from the block, and computes
|
||||||
|
// dot product with the query.
|
||||||
|
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||||
|
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
|
||||||
|
const int physical_block_number = block_table[block_idx];
|
||||||
|
|
||||||
|
// Load a key to registers.
|
||||||
|
// Each thread in a thread group has a different part of the key.
|
||||||
|
// For example, if the the thread group size is 4, then the first thread in the group
|
||||||
|
// has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
|
||||||
|
// vectors of the key, and so on.
|
||||||
|
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
|
||||||
|
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
|
||||||
|
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
|
||||||
|
K_vec k_vecs[NUM_VECS_PER_THREAD];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
|
||||||
|
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
|
||||||
|
+ kv_head_idx * kv_head_stride
|
||||||
|
+ physical_block_offset * x;
|
||||||
|
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
||||||
|
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
||||||
|
const int offset2 = (vec_idx * VEC_SIZE) % x;
|
||||||
|
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute dot product.
|
||||||
|
// This includes a reduction across the threads in the same thread group.
|
||||||
|
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
|
||||||
|
// Add the ALiBi bias if slopes are given.
|
||||||
|
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
|
||||||
|
|
||||||
|
if (thread_group_offset == 0) {
|
||||||
|
// Store the partial reductions to shared memory.
|
||||||
|
// NOTE(woosuk): It is required to zero out the masked logits.
|
||||||
|
const bool mask = token_idx >= context_len;
|
||||||
|
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
|
||||||
|
// Update the max value.
|
||||||
|
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform reduction across the threads in the same warp to get the
|
||||||
|
// max qk value for each "warp" (not across the thread block yet).
|
||||||
|
// The 0-th thread of each thread group already has its max qk value.
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
|
||||||
|
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
||||||
|
}
|
||||||
|
if (lane == 0) {
|
||||||
|
red_smem[warp_idx] = qk_max;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// TODO(woosuk): Refactor this part.
|
||||||
|
// Get the max qk value for the sequence.
|
||||||
|
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||||
|
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
||||||
|
}
|
||||||
|
// Broadcast the max qk value to all threads.
|
||||||
|
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
|
||||||
|
|
||||||
|
// Get the sum of the exp values.
|
||||||
|
float exp_sum = 0.f;
|
||||||
|
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
|
||||||
|
float val = __expf(logits[i] - qk_max);
|
||||||
|
logits[i] = val;
|
||||||
|
exp_sum += val;
|
||||||
|
}
|
||||||
|
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
|
||||||
|
|
||||||
|
// Compute softmax.
|
||||||
|
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
|
||||||
|
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
|
||||||
|
logits[i] *= inv_sum;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// If partitioning is enabled, store the max logit and exp_sum.
|
||||||
|
if (USE_PARTITIONING && thread_idx == 0) {
|
||||||
|
float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
|
||||||
|
+ head_idx * max_num_partitions
|
||||||
|
+ partition_idx;
|
||||||
|
*max_logits_ptr = qk_max;
|
||||||
|
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
|
||||||
|
+ head_idx * max_num_partitions
|
||||||
|
+ partition_idx;
|
||||||
|
*exp_sums_ptr = exp_sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Each thread will fetch 16 bytes from the value cache at a time.
|
||||||
|
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
|
||||||
|
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||||
|
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||||
|
using Float_L_vec = typename FloatVec<L_vec>::Type;
|
||||||
|
|
||||||
|
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
|
||||||
|
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
|
||||||
|
constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
|
||||||
|
|
||||||
|
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
|
||||||
|
float accs[NUM_ROWS_PER_THREAD];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||||
|
accs[i] = 0.f;
|
||||||
|
}
|
||||||
|
|
||||||
|
scalar_t zero_value;
|
||||||
|
zero(zero_value);
|
||||||
|
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
|
||||||
|
const int physical_block_number = block_table[block_idx];
|
||||||
|
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
|
||||||
|
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
|
||||||
|
L_vec logits_vec;
|
||||||
|
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
|
||||||
|
|
||||||
|
const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
|
||||||
|
+ kv_head_idx * kv_head_stride;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||||
|
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||||
|
if (row_idx < HEAD_SIZE) {
|
||||||
|
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
||||||
|
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
||||||
|
if (block_idx == num_context_blocks - 1) {
|
||||||
|
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
|
||||||
|
// we should explicitly zero out the values since they may contain NaNs.
|
||||||
|
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
|
||||||
|
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < V_VEC_SIZE; j++) {
|
||||||
|
v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
accs[i] += dot(logits_vec, v_vec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform reduction within each warp.
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||||
|
float acc = accs[i];
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
|
||||||
|
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
|
||||||
|
}
|
||||||
|
accs[i] = acc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE(woosuk): A barrier is required because the shared memory space for logits
|
||||||
|
// is reused for the output.
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Perform reduction across warps.
|
||||||
|
float* out_smem = reinterpret_cast<float*>(shared_mem);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = NUM_WARPS; i > 1; i /= 2) {
|
||||||
|
int mid = i / 2;
|
||||||
|
// Upper warps write to shared memory.
|
||||||
|
if (warp_idx >= mid && warp_idx < i) {
|
||||||
|
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||||
|
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||||
|
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
|
||||||
|
dst[row_idx] = accs[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Lower warps update the output.
|
||||||
|
if (warp_idx < mid) {
|
||||||
|
const float* src = &out_smem[warp_idx * HEAD_SIZE];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||||
|
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||||
|
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
|
||||||
|
accs[i] += src[row_idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write the final output.
|
||||||
|
if (warp_idx == 0) {
|
||||||
|
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
||||||
|
+ head_idx * max_num_partitions * HEAD_SIZE
|
||||||
|
+ partition_idx * HEAD_SIZE;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||||
|
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||||
|
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
|
||||||
|
from_float(*(out_ptr + row_idx), accs[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Grid: (num_heads, num_seqs, 1).
|
||||||
|
template<
|
||||||
|
typename scalar_t,
|
||||||
|
int HEAD_SIZE,
|
||||||
|
int BLOCK_SIZE,
|
||||||
|
int NUM_THREADS>
|
||||||
|
__global__ void paged_attention_v1_kernel(
|
||||||
|
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||||
|
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||||
|
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||||
|
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||||
|
const int* __restrict__ head_mapping, // [num_heads]
|
||||||
|
const float scale,
|
||||||
|
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
|
const int* __restrict__ context_lens, // [num_seqs]
|
||||||
|
const int max_num_blocks_per_seq,
|
||||||
|
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||||
|
const int q_stride,
|
||||||
|
const int kv_block_stride,
|
||||||
|
const int kv_head_stride) {
|
||||||
|
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
|
||||||
|
/* exp_sums */ nullptr, /* max_logits */ nullptr,
|
||||||
|
out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens,
|
||||||
|
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Grid: (num_heads, num_seqs, max_num_partitions).
|
||||||
|
template<
|
||||||
|
typename scalar_t,
|
||||||
|
int HEAD_SIZE,
|
||||||
|
int BLOCK_SIZE,
|
||||||
|
int NUM_THREADS,
|
||||||
|
int PARTITION_SIZE>
|
||||||
|
__global__ void paged_attention_v2_kernel(
|
||||||
|
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||||
|
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||||
|
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||||
|
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||||
|
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||||
|
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||||
|
const int* __restrict__ head_mapping, // [num_heads]
|
||||||
|
const float scale,
|
||||||
|
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
|
const int* __restrict__ context_lens, // [num_seqs]
|
||||||
|
const int max_num_blocks_per_seq,
|
||||||
|
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||||
|
const int q_stride,
|
||||||
|
const int kv_block_stride,
|
||||||
|
const int kv_head_stride) {
|
||||||
|
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
|
||||||
|
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale,
|
||||||
|
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
|
||||||
|
q_stride, kv_block_stride, kv_head_stride);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Grid: (num_heads, num_seqs).
|
||||||
|
template<
|
||||||
|
typename scalar_t,
|
||||||
|
int HEAD_SIZE,
|
||||||
|
int NUM_THREADS,
|
||||||
|
int PARTITION_SIZE>
|
||||||
|
__global__ void paged_attention_v2_reduce_kernel(
|
||||||
|
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||||
|
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||||
|
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||||
|
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||||
|
const int* __restrict__ context_lens, // [num_seqs]
|
||||||
|
const int max_num_partitions) {
|
||||||
|
const int num_heads = gridDim.x;
|
||||||
|
const int head_idx = blockIdx.x;
|
||||||
|
const int seq_idx = blockIdx.y;
|
||||||
|
const int context_len = context_lens[seq_idx];
|
||||||
|
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
|
||||||
|
if (num_partitions == 1) {
|
||||||
|
// No need to reduce. Only copy tmp_out to out.
|
||||||
|
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||||
|
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
||||||
|
+ head_idx * max_num_partitions * HEAD_SIZE;
|
||||||
|
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
|
||||||
|
out_ptr[i] = tmp_out_ptr[i];
|
||||||
|
}
|
||||||
|
// Terminate the thread block.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||||
|
const int warp_idx = threadIdx.x / WARP_SIZE;
|
||||||
|
const int lane = threadIdx.x % WARP_SIZE;
|
||||||
|
|
||||||
|
// Size: 2 * num_partitions.
|
||||||
|
extern __shared__ char shared_mem[];
|
||||||
|
// Workspace for reduction.
|
||||||
|
__shared__ float red_smem[2 * NUM_WARPS];
|
||||||
|
|
||||||
|
// Load max logits to shared memory.
|
||||||
|
float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
|
||||||
|
const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
|
||||||
|
+ head_idx * max_num_partitions;
|
||||||
|
float max_logit = -FLT_MAX;
|
||||||
|
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
|
||||||
|
const float l = max_logits_ptr[i];
|
||||||
|
shared_max_logits[i] = l;
|
||||||
|
max_logit = fmaxf(max_logit, l);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Get the global max logit.
|
||||||
|
// Reduce within the warp.
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||||
|
max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
|
||||||
|
}
|
||||||
|
if (lane == 0) {
|
||||||
|
red_smem[warp_idx] = max_logit;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
// Reduce across warps.
|
||||||
|
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||||
|
max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
|
||||||
|
}
|
||||||
|
// Broadcast the max value to all threads.
|
||||||
|
max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
|
||||||
|
|
||||||
|
// Load rescaled exp sums to shared memory.
|
||||||
|
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
|
||||||
|
const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
|
||||||
|
+ head_idx * max_num_partitions;
|
||||||
|
float global_exp_sum = 0.0f;
|
||||||
|
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
|
||||||
|
float l = shared_max_logits[i];
|
||||||
|
float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
|
||||||
|
global_exp_sum += rescaled_exp_sum;
|
||||||
|
shared_exp_sums[i] = rescaled_exp_sum;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
|
||||||
|
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
|
||||||
|
|
||||||
|
// Aggregate tmp_out to out.
|
||||||
|
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
||||||
|
+ head_idx * max_num_partitions * HEAD_SIZE;
|
||||||
|
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
|
||||||
|
float acc = 0.0f;
|
||||||
|
for (int j = 0; j < num_partitions; ++j) {
|
||||||
|
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
|
||||||
|
}
|
||||||
|
from_float(out_ptr[i], acc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
|
||||||
|
cudaFuncSetAttribute( \
|
||||||
|
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
|
||||||
|
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
|
||||||
|
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
|
||||||
|
<<<grid, block, shared_mem_size, stream>>>( \
|
||||||
|
out_ptr, \
|
||||||
|
query_ptr, \
|
||||||
|
key_cache_ptr, \
|
||||||
|
value_cache_ptr, \
|
||||||
|
head_mapping_ptr, \
|
||||||
|
scale, \
|
||||||
|
block_tables_ptr, \
|
||||||
|
context_lens_ptr, \
|
||||||
|
max_num_blocks_per_seq, \
|
||||||
|
alibi_slopes_ptr, \
|
||||||
|
q_stride, \
|
||||||
|
kv_block_stride, \
|
||||||
|
kv_head_stride);
|
||||||
|
|
||||||
|
// TODO(woosuk): Tune NUM_THREADS.
|
||||||
|
template<
|
||||||
|
typename T,
|
||||||
|
int BLOCK_SIZE,
|
||||||
|
int NUM_THREADS = 128>
|
||||||
|
void paged_attention_v1_launcher(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& query,
|
||||||
|
torch::Tensor& key_cache,
|
||||||
|
torch::Tensor& value_cache,
|
||||||
|
torch::Tensor& head_mapping,
|
||||||
|
float scale,
|
||||||
|
torch::Tensor& block_tables,
|
||||||
|
torch::Tensor& context_lens,
|
||||||
|
int max_context_len,
|
||||||
|
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||||
|
int num_seqs = query.size(0);
|
||||||
|
int num_heads = query.size(1);
|
||||||
|
int head_size = query.size(2);
|
||||||
|
int max_num_blocks_per_seq = block_tables.size(1);
|
||||||
|
int q_stride = query.stride(0);
|
||||||
|
int kv_block_stride = key_cache.stride(0);
|
||||||
|
int kv_head_stride = key_cache.stride(1);
|
||||||
|
|
||||||
|
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||||
|
assert(head_size % thread_group_size == 0);
|
||||||
|
|
||||||
|
// NOTE: alibi_slopes is optional.
|
||||||
|
const float* alibi_slopes_ptr = alibi_slopes ?
|
||||||
|
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
|
||||||
|
: nullptr;
|
||||||
|
|
||||||
|
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||||
|
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||||
|
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||||
|
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||||
|
int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
|
||||||
|
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||||
|
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||||
|
|
||||||
|
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||||
|
int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
|
||||||
|
int logits_size = padded_max_context_len * sizeof(float);
|
||||||
|
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
||||||
|
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
|
||||||
|
// Keep that in sync with the logic here!
|
||||||
|
int shared_mem_size = std::max(logits_size, outputs_size);
|
||||||
|
|
||||||
|
dim3 grid(num_heads, num_seqs, 1);
|
||||||
|
dim3 block(NUM_THREADS);
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
switch (head_size) {
|
||||||
|
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
||||||
|
// head sizes that we use in the model. However, we can easily extend this
|
||||||
|
// to support any head size which is a multiple of 16.
|
||||||
|
case 64:
|
||||||
|
LAUNCH_PAGED_ATTENTION_V1(64);
|
||||||
|
break;
|
||||||
|
case 80:
|
||||||
|
LAUNCH_PAGED_ATTENTION_V1(80);
|
||||||
|
break;
|
||||||
|
case 96:
|
||||||
|
LAUNCH_PAGED_ATTENTION_V1(96);
|
||||||
|
break;
|
||||||
|
case 112:
|
||||||
|
LAUNCH_PAGED_ATTENTION_V1(112);
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
LAUNCH_PAGED_ATTENTION_V1(128);
|
||||||
|
break;
|
||||||
|
case 256:
|
||||||
|
LAUNCH_PAGED_ATTENTION_V1(256);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \
|
||||||
|
paged_attention_v1_launcher<T, BLOCK_SIZE>( \
|
||||||
|
out, \
|
||||||
|
query, \
|
||||||
|
key_cache, \
|
||||||
|
value_cache, \
|
||||||
|
head_mapping, \
|
||||||
|
scale, \
|
||||||
|
block_tables, \
|
||||||
|
context_lens, \
|
||||||
|
max_context_len, \
|
||||||
|
alibi_slopes);
|
||||||
|
|
||||||
|
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||||
|
// 1, 2, 4, 64, 128, 256.
|
||||||
|
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \
|
||||||
|
switch (block_size) { \
|
||||||
|
case 8: \
|
||||||
|
CALL_V1_LAUNCHER(T, 8); \
|
||||||
|
break; \
|
||||||
|
case 16: \
|
||||||
|
CALL_V1_LAUNCHER(T, 16); \
|
||||||
|
break; \
|
||||||
|
case 32: \
|
||||||
|
CALL_V1_LAUNCHER(T, 32); \
|
||||||
|
break; \
|
||||||
|
default: \
|
||||||
|
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||||
|
break; \
|
||||||
|
}
|
||||||
|
|
||||||
|
void paged_attention_v1(
|
||||||
|
torch::Tensor& out, // [num_seqs, num_heads, head_size]
|
||||||
|
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||||
|
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
|
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
|
torch::Tensor& head_mapping, // [num_heads]
|
||||||
|
float scale,
|
||||||
|
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
|
torch::Tensor& context_lens, // [num_seqs]
|
||||||
|
int block_size,
|
||||||
|
int max_context_len,
|
||||||
|
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||||
|
if (query.dtype() == at::ScalarType::Float) {
|
||||||
|
CALL_V1_LAUNCHER_BLOCK_SIZE(float);
|
||||||
|
} else if (query.dtype() == at::ScalarType::Half) {
|
||||||
|
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t);
|
||||||
|
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||||
|
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
|
||||||
|
vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE> \
|
||||||
|
<<<grid, block, shared_mem_size, stream>>>( \
|
||||||
|
exp_sums_ptr, \
|
||||||
|
max_logits_ptr, \
|
||||||
|
tmp_out_ptr, \
|
||||||
|
query_ptr, \
|
||||||
|
key_cache_ptr, \
|
||||||
|
value_cache_ptr, \
|
||||||
|
head_mapping_ptr, \
|
||||||
|
scale, \
|
||||||
|
block_tables_ptr, \
|
||||||
|
context_lens_ptr, \
|
||||||
|
max_num_blocks_per_seq, \
|
||||||
|
alibi_slopes_ptr, \
|
||||||
|
q_stride, \
|
||||||
|
kv_block_stride, \
|
||||||
|
kv_head_stride); \
|
||||||
|
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
|
||||||
|
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
|
||||||
|
out_ptr, \
|
||||||
|
exp_sums_ptr, \
|
||||||
|
max_logits_ptr, \
|
||||||
|
tmp_out_ptr, \
|
||||||
|
context_lens_ptr, \
|
||||||
|
max_num_partitions);
|
||||||
|
|
||||||
|
template<
|
||||||
|
typename T,
|
||||||
|
int BLOCK_SIZE,
|
||||||
|
int NUM_THREADS = 128,
|
||||||
|
int PARTITION_SIZE = 512>
|
||||||
|
void paged_attention_v2_launcher(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& exp_sums,
|
||||||
|
torch::Tensor& max_logits,
|
||||||
|
torch::Tensor& tmp_out,
|
||||||
|
torch::Tensor& query,
|
||||||
|
torch::Tensor& key_cache,
|
||||||
|
torch::Tensor& value_cache,
|
||||||
|
torch::Tensor& head_mapping,
|
||||||
|
float scale,
|
||||||
|
torch::Tensor& block_tables,
|
||||||
|
torch::Tensor& context_lens,
|
||||||
|
int max_context_len,
|
||||||
|
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||||
|
int num_seqs = query.size(0);
|
||||||
|
int num_heads = query.size(1);
|
||||||
|
int head_size = query.size(2);
|
||||||
|
int max_num_blocks_per_seq = block_tables.size(1);
|
||||||
|
int q_stride = query.stride(0);
|
||||||
|
int kv_block_stride = key_cache.stride(0);
|
||||||
|
int kv_head_stride = key_cache.stride(1);
|
||||||
|
|
||||||
|
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||||
|
assert(head_size % thread_group_size == 0);
|
||||||
|
|
||||||
|
// NOTE: alibi_slopes is optional.
|
||||||
|
const float* alibi_slopes_ptr = alibi_slopes ?
|
||||||
|
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
|
||||||
|
: nullptr;
|
||||||
|
|
||||||
|
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||||
|
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
|
||||||
|
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
|
||||||
|
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
|
||||||
|
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||||
|
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||||
|
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||||
|
int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
|
||||||
|
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||||
|
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||||
|
|
||||||
|
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||||
|
int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
|
||||||
|
int logits_size = PARTITION_SIZE * sizeof(float);
|
||||||
|
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
||||||
|
|
||||||
|
// For paged attention v2 kernel.
|
||||||
|
dim3 grid(num_heads, num_seqs, max_num_partitions);
|
||||||
|
int shared_mem_size = std::max(logits_size, outputs_size);
|
||||||
|
// For paged attention v2 reduce kernel.
|
||||||
|
dim3 reduce_grid(num_heads, num_seqs);
|
||||||
|
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
|
||||||
|
|
||||||
|
dim3 block(NUM_THREADS);
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
switch (head_size) {
|
||||||
|
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
||||||
|
// head sizes that we use in the model. However, we can easily extend this
|
||||||
|
// to support any head size which is a multiple of 16.
|
||||||
|
case 64:
|
||||||
|
LAUNCH_PAGED_ATTENTION_V2(64);
|
||||||
|
break;
|
||||||
|
case 80:
|
||||||
|
LAUNCH_PAGED_ATTENTION_V2(80);
|
||||||
|
break;
|
||||||
|
case 96:
|
||||||
|
LAUNCH_PAGED_ATTENTION_V2(96);
|
||||||
|
break;
|
||||||
|
case 112:
|
||||||
|
LAUNCH_PAGED_ATTENTION_V2(112);
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
LAUNCH_PAGED_ATTENTION_V2(128);
|
||||||
|
break;
|
||||||
|
case 256:
|
||||||
|
LAUNCH_PAGED_ATTENTION_V2(256);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \
|
||||||
|
paged_attention_v2_launcher<T, BLOCK_SIZE>( \
|
||||||
|
out, \
|
||||||
|
exp_sums, \
|
||||||
|
max_logits, \
|
||||||
|
tmp_out, \
|
||||||
|
query, \
|
||||||
|
key_cache, \
|
||||||
|
value_cache, \
|
||||||
|
head_mapping, \
|
||||||
|
scale, \
|
||||||
|
block_tables, \
|
||||||
|
context_lens, \
|
||||||
|
max_context_len, \
|
||||||
|
alibi_slopes);
|
||||||
|
|
||||||
|
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||||
|
// 1, 2, 4, 64, 128, 256.
|
||||||
|
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \
|
||||||
|
switch (block_size) { \
|
||||||
|
case 8: \
|
||||||
|
CALL_V2_LAUNCHER(T, 8); \
|
||||||
|
break; \
|
||||||
|
case 16: \
|
||||||
|
CALL_V2_LAUNCHER(T, 16); \
|
||||||
|
break; \
|
||||||
|
case 32: \
|
||||||
|
CALL_V2_LAUNCHER(T, 32); \
|
||||||
|
break; \
|
||||||
|
default: \
|
||||||
|
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||||
|
break; \
|
||||||
|
}
|
||||||
|
|
||||||
|
void paged_attention_v2(
|
||||||
|
torch::Tensor& out, // [num_seqs, num_heads, head_size]
|
||||||
|
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||||
|
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||||
|
torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||||
|
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||||
|
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
|
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
|
torch::Tensor& head_mapping, // [num_heads]
|
||||||
|
float scale,
|
||||||
|
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
|
torch::Tensor& context_lens, // [num_seqs]
|
||||||
|
int block_size,
|
||||||
|
int max_context_len,
|
||||||
|
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||||
|
if (query.dtype() == at::ScalarType::Float) {
|
||||||
|
CALL_V2_LAUNCHER_BLOCK_SIZE(float);
|
||||||
|
} else if (query.dtype() == at::ScalarType::Half) {
|
||||||
|
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t);
|
||||||
|
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||||
|
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef WARP_SIZE
|
||||||
|
#undef MAX
|
||||||
|
#undef MIN
|
||||||
|
#undef DIVIDE_ROUND_UP
|
||||||
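To make the V2 two-pass scheme above concrete, here is a NumPy reference (a sketch, not part of the source) of the merge performed by paged_attention_v2_reduce_kernel: each partition contributes its local max logit, its exp-sum, and a partial output computed with a locally normalized softmax, and the reduce step rescales those partials and combines them into the final head output.

    import numpy as np

    def merge_partitions(max_logits, exp_sums, tmp_out):
        # max_logits: [P], exp_sums: [P], tmp_out: [P, head_size]
        global_max = max_logits.max()
        rescaled = exp_sums * np.exp(max_logits - global_max)   # per-partition weight
        weights = rescaled / rescaled.sum()
        return (weights[:, None] * tmp_out).sum(axis=0)          # [head_size]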
55  csrc/attention/attention_utils.cuh  Normal file
@ -0,0 +1,55 @@
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_dtypes.h"

#include <float.h>
#include <type_traits>

namespace vllm {

// Q*K^T operation.
template<int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
  using A_vec = typename FloatVec<Vec>::Type;
  // Compute the parallel products for Q*K^T (treat vector lanes separately).
  A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
  for (int ii = 1; ii < N; ++ii) {
    qk_vec = fma(q[ii], k[ii], qk_vec);
  }

  // Finalize the reduction across lanes.
  float qk = sum(qk_vec);
#pragma unroll
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
  }
  return qk;
}

template<typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
  template<typename Vec, int N>
  static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
    return qk_dot_<THREAD_GROUP_SIZE>(q, k);
  }
};

} // namespace vllm
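For intuition, the lane splitting and shuffle reduction in qk_dot_ collapse to an ordinary dot product over the head dimension once all lanes and threads are combined; a NumPy reference (illustrative only):

    import numpy as np

    def qk_dot_ref(q: np.ndarray, k: np.ndarray) -> float:
        # Same result as Qk_dot::dot after the full reduction, accumulated in
        # float32 as the kernels do.
        return float(np.dot(q.astype(np.float32), k.astype(np.float32)))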
438  csrc/attention/dtype_bfloat16.cuh  Normal file
@ -0,0 +1,438 @@
|
|||||||
|
/*
|
||||||
|
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||||
|
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||||
|
* Copyright (c) 2023, The vLLM team.
|
||||||
|
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "attention_generic.cuh"
|
||||||
|
#include "dtype_float32.cuh"
|
||||||
|
|
||||||
|
#include <cuda_bf16.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
// Define custom BF16 vector data types.
|
||||||
|
struct bf16_4_t {
|
||||||
|
__nv_bfloat162 x;
|
||||||
|
__nv_bfloat162 y;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct bf16_8_t {
|
||||||
|
__nv_bfloat162 x;
|
||||||
|
__nv_bfloat162 y;
|
||||||
|
__nv_bfloat162 z;
|
||||||
|
__nv_bfloat162 w;
|
||||||
|
};
|
||||||
|
|
||||||
|
// BF16 vector types for Q, K, V.
|
||||||
|
template<>
|
||||||
|
struct Vec<__nv_bfloat16, 1> {
|
||||||
|
using Type = __nv_bfloat16;
|
||||||
|
};
|
||||||
|
template<>
|
||||||
|
struct Vec<__nv_bfloat16, 2> {
|
||||||
|
using Type = __nv_bfloat162;
|
||||||
|
};
|
||||||
|
template<>
|
||||||
|
struct Vec<__nv_bfloat16, 4> {
|
||||||
|
using Type = bf16_4_t;
|
||||||
|
};
|
||||||
|
template<>
|
||||||
|
struct Vec<__nv_bfloat16, 8> {
|
||||||
|
using Type = bf16_8_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
// FP32 accumulator vector types corresponding to Vec.
|
||||||
|
template<>
|
||||||
|
struct FloatVec<__nv_bfloat16> {
|
||||||
|
using Type = float;
|
||||||
|
};
|
||||||
|
template<>
|
||||||
|
struct FloatVec<__nv_bfloat162> {
  using Type = float2;
};
template<>
struct FloatVec<bf16_4_t> {
  using Type = Float4_;
};
template<>
struct FloatVec<bf16_8_t> {
  using Type = Float8_;
};

// Utility functions for type conversions.
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __bfloat1622float2(val);
#endif
}

inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __bfloat162bfloat162(val);
#endif
}

// Vector addition.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return a + b;
#endif
}

inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hadd2(a, b);
#endif
}

inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
  float2 fa = bf1622float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}

// Vector multiplication.
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hmul(a, b);
#endif
}

template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hmul2(a, b);
#endif
}

template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

template<>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return c;
}

template<>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return c;
}

template<>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return c;
}

template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t c;
  c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return c;
}

template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
  float fa = __bfloat162float(a);
  float fb = __bfloat162float(b);
  return fa * fb;
}

template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
  return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  return fc;
}

template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  return fc;
}

template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
  return fc;
}

template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fc;
  fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
  fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
  fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
  fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
  return fc;
}

// Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hfma2(a, b, c);
#endif
}

inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return __hfma2(bf162bf162(a), b, c);
#endif
}

inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
  bf16_4_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_4_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
  bf16_8_t d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
  __nv_bfloat162 s = bf162bf162(a);
  bf16_8_t d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
  return __bfloat162float(a) * __bfloat162float(b) + fc;
}

inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
  float2 fa = bf1622float2(a);
  float2 fb = bf1622float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
  return fma(bf162bf162(a), b, fc);
}

inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
  __nv_bfloat162 s = bf162bf162(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// Vector sum.
template<>
inline __device__ float sum(__nv_bfloat16 v) {
  return __bfloat162float(v);
}

template<>
inline __device__ float sum(__nv_bfloat162 v) {
  float2 vf = bf1622float2(v);
  return vf.x + vf.y;
}

template<>
inline __device__ float sum(bf16_4_t v) {
  return sum(v.x) + sum(v.y);
}

template<>
inline __device__ float sum(bf16_8_t v) {
  return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}

// From float32 to bfloat16.
inline __device__ void from_float(__nv_bfloat16& dst, float src) {
  dst = __float2bfloat16(src);
}

inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst = __float22bfloat162_rn(src);
#endif
}

inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
#endif
}

inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  dst.x = __float22bfloat162_rn(src.x);
  dst.y = __float22bfloat162_rn(src.y);
  dst.z = __float22bfloat162_rn(src.z);
  dst.w = __float22bfloat162_rn(src.w);
#endif
}

// From bfloat16 to float32.
inline __device__ float to_float(__nv_bfloat16 u) {
  return __bfloat162float(u);
}

// Zero-out a variable.
inline __device__ void zero(__nv_bfloat16& dst) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
  dst = __ushort_as_bfloat16((unsigned short)0x0000U);
#endif
}

} // namespace vllm
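The trait plumbing above is what keeps the attention kernels type-generic: `Vec<T, N>` names a packed register type and `FloatVec<...>` names its FP32 widening, while `mul`/`fma`/`sum` are overloaded per storage type. A minimal sketch of how these overloads compose, assuming `dtype_bfloat16.cuh` and the `attention_generic.cuh` primary templates are included; the kernel name, launch shape, and alignment handling are illustrative, not part of the diff:

```cuda
// Sketch only: FP32 accumulation over bf16 data via the FloatVec trait.
// Assumes n is a multiple of 8, pointers are 16-byte aligned, and *out == 0.
#include <cuda_bf16.h>
#include "dtype_bfloat16.cuh"   // assumed include path

__global__ void bf16_dot_sketch(const __nv_bfloat16* a, const __nv_bfloat16* b,
                                float* out, int n) {
  using A_vec = vllm::bf16_8_t;                        // 8 bf16 values per load
  using F_vec = typename vllm::FloatVec<A_vec>::Type;  // Float8_ accumulator
  float acc = 0.f;
  for (int i = (blockIdx.x * blockDim.x + threadIdx.x) * 8; i < n;
       i += gridDim.x * blockDim.x * 8) {
    A_vec va = *reinterpret_cast<const A_vec*>(a + i);
    A_vec vb = *reinterpret_cast<const A_vec*>(b + i);
    // Multiply in bf16, widen to FP32 through FloatVec, reduce with sum().
    F_vec prod = vllm::mul<F_vec, A_vec, A_vec>(va, vb);
    acc += vllm::sum(prod);
  }
  atomicAdd(out, acc);
}
```

The attention kernel relies on the same pattern: arithmetic in the storage type, widening through `FloatVec`, and an FP32 reduction at the end.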
csrc/attention/dtype_float16.cuh  (new file, 444 lines)
@@ -0,0 +1,444 @@
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#include <stdint.h>

namespace vllm {

// FP16 vector types for Q, K, V.
template<>
struct Vec<uint16_t, 1> {
  using Type = uint16_t;
};
template<>
struct Vec<uint16_t, 2> {
  using Type = uint32_t;
};
template<>
struct Vec<uint16_t, 4> {
  using Type = uint2;
};
template<>
struct Vec<uint16_t, 8> {
  using Type = uint4;
};

// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<uint16_t> {
  using Type = float;
};
template<>
struct FloatVec<uint32_t> {
  using Type = float2;
};
template<>
struct FloatVec<uint2> {
  using Type = Float4_;
};
template<>
struct FloatVec<uint4> {
  using Type = Float8_;
};

// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
  uint32_t b;
  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
  return b;
}

inline __device__ float half_to_float(uint16_t h) {
  float f;
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
  return f;
}

inline __device__ float2 half2_to_float2(uint32_t v) {
  uint16_t lo, hi;
  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
  return make_float2(half_to_float(lo), half_to_float(hi));
}

inline __device__ uint16_t float_to_half(float f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
  return tmp.u16[0];
}

inline __device__ uint32_t float2_to_half2(float2 f) {
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#else
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
  return tmp.u32;
}

// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t c;
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
  return c;
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t c;
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
  return c;
}

inline __device__ uint2 add(uint2 a, uint2 b) {
  uint2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ uint4 add(uint4 a, uint4 b) {
  uint4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

inline __device__ float2 add(uint32_t a, float2 fb) {
  float2 fa = half2_to_float2(a);
  return add(fa, fb);
}

inline __device__ Float4_ add(uint2 a, Float4_ fb) {
  Float4_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  return fc;
}

inline __device__ Float8_ add(uint4 a, Float8_ fb) {
  Float8_ fc;
  fc.x = add(a.x, fb.x);
  fc.y = add(a.y, fb.y);
  fc.z = add(a.z, fb.z);
  fc.w = add(a.w, fb.w);
  return fc;
}

// Vector multiplication.
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t c;
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
  return c;
}

template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t c;
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
  return c;
}

template<>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
  return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}

template<>
inline __device__ uint2 mul(uint2 a, uint2 b) {
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  return c;
}

template<>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  uint2 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  return c;
}

template<>
inline __device__ uint4 mul(uint4 a, uint4 b) {
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
  return c;
}

template<>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  uint4 c;
  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
  c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
  c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
  return c;
}

template<>
inline __device__ float mul(uint16_t a, uint16_t b) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb;
}

template<>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return mul<float2, float2, float2>(fa, fb);
}

template<>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
  return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}

template<>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  return fc;
}

template<>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
  uint32_t s = h0_h0(a);
  Float4_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  return fc;
}

template<>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
  return fc;
}

template<>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
  uint32_t s = h0_h0(a);
  Float8_ fc;
  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
  fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
  fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
  return fc;
}

// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
  return d;
}

inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
  return fma(h0_h0(a), b, c);
}

inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
  uint2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
  uint32_t s = h0_h0(a);
  uint2 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  return d;
}

inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
  uint4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
  uint32_t s = h0_h0(a);
  uint4 d;
  d.x = fma(s, b.x, c.x);
  d.y = fma(s, b.y, c.y);
  d.z = fma(s, b.z, c.z);
  d.w = fma(s, b.w, c.w);
  return d;
}

inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
  float fa = half_to_float(a);
  float fb = half_to_float(b);
  return fa * fb + fc;
}

inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
  float2 fa = half2_to_float2(a);
  float2 fb = half2_to_float2(b);
  return fma(fa, fb, fc);
}

inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
  return fma(h0_h0(a), b, fc);
}

inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
  Float4_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  return fd;
}

inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
  uint32_t s = h0_h0(a);
  Float4_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  return fd;
}

inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
  Float8_ fd;
  fd.x = fma(a.x, b.x, fc.x);
  fd.y = fma(a.y, b.y, fc.y);
  fd.z = fma(a.z, b.z, fc.z);
  fd.w = fma(a.w, b.w, fc.w);
  return fd;
}

inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
  uint32_t s = h0_h0(a);
  Float8_ fd;
  fd.x = fma(s, b.x, fc.x);
  fd.y = fma(s, b.y, fc.y);
  fd.z = fma(s, b.z, fc.z);
  fd.w = fma(s, b.w, fc.w);
  return fd;
}

// Vector sum.
template<>
inline __device__ float sum(uint16_t v) {
  return half_to_float(v);
}

template<>
inline __device__ float sum(uint32_t v) {
  float2 tmp = half2_to_float2(v);
  return tmp.x + tmp.y;
}

template<>
inline __device__ float sum(uint2 v) {
  uint32_t c = add(v.x, v.y);
  return sum(c);
}

template<>
inline __device__ float sum(uint4 v) {
  uint32_t c = add(v.x, v.y);
  c = add(c, v.z);
  c = add(c, v.w);
  return sum(c);
}

// From float32 to float16.
inline __device__ void from_float(uint16_t& dst, float src) {
  dst = float_to_half(src);
}

inline __device__ void from_float(uint32_t& dst, float2 src) {
  dst = float2_to_half2(src);
}

inline __device__ void from_float(uint2& dst, Float4_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
}

inline __device__ void from_float(uint4& dst, Float8_ src) {
  dst.x = float2_to_half2(src.x);
  dst.y = float2_to_half2(src.y);
  dst.z = float2_to_half2(src.z);
  dst.w = float2_to_half2(src.w);
}

// From float16 to float32.
inline __device__ float to_float(uint16_t u) {
  return half_to_float(u);
}

inline __device__ float2 to_float(uint32_t u) {
  return half2_to_float2(u);
}

inline __device__ Float4_ to_float(uint2 u) {
  Float4_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  return tmp;
}

inline __device__ Float8_ to_float(uint4 u) {
  Float8_ tmp;
  tmp.x = half2_to_float2(u.x);
  tmp.y = half2_to_float2(u.y);
  tmp.z = half2_to_float2(u.z);
  tmp.w = half2_to_float2(u.w);
  return tmp;
}

// Zero-out a variable.
inline __device__ void zero(uint16_t& dst) {
  dst = uint16_t(0);
}

} // namespace vllm
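On the FP16 path, halves travel as raw `uint16_t`/`uint32_t` bit patterns, so all arithmetic and conversion goes through the inline PTX helpers above. A hedged sketch of how they compose, assuming the header is included; the kernel name and shapes are hypothetical, not from the diff:

```cuda
// Sketch only: packed-half2 axpy using the fma() and h0_h0() helpers above.
#include "dtype_float16.cuh"   // assumed include path

__global__ void axpy_fp16_sketch(uint16_t alpha, const uint32_t* x,
                                 uint32_t* y, int n_pairs) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_pairs) {
    // y[i] = alpha * x[i] + y[i], computed as a packed half2 fma.
    // h0_h0() (called inside the fma overload) broadcasts the scalar half
    // into both lanes of a half2 register.
    y[i] = vllm::fma(alpha, x[i], y[i]);
  }
}
```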
csrc/attention/dtype_float32.cuh  (new file, 273 lines)
@@ -0,0 +1,273 @@
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "attention_generic.cuh"

#include <stdint.h>

namespace vllm {

// Define custom FP32 vector data types.
struct Float4_ {
  float2 x;
  float2 y;
};

struct Float8_ {
  float2 x;
  float2 y;
  float2 z;
  float2 w;
};

// FP32 vector types for Q, K, V.
template<>
struct Vec<float, 1> {
  using Type = float;
};
template<>
struct Vec<float, 2> {
  using Type = float2;
};
template<>
struct Vec<float, 4> {
  using Type = float4;
};

// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<float> {
  using Type = float;
};
template<>
struct FloatVec<float2> {
  using Type = float2;
};
template<>
struct FloatVec<float4> {
  using Type = float4;
};

// Vector addition.
inline __device__ float add(float a, float b) {
  return a + b;
}

inline __device__ float2 add(float2 a, float2 b) {
  float2 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  return c;
}

inline __device__ float4 add(float4 a, float4 b) {
  float4 c;
  c.x = add(a.x, b.x);
  c.y = add(a.y, b.y);
  c.z = add(a.z, b.z);
  c.w = add(a.w, b.w);
  return c;
}

// Vector multiplication.
template<>
inline __device__ float mul<float, float>(float a, float b) {
  return a * b;
}

template<>
inline __device__ float2 mul(float2 a, float2 b) {
  float2 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  return c;
}

template<>
inline __device__ float2 mul(float a, float2 b) {
  float2 c;
  c.x = a * b.x;
  c.y = a * b.y;
  return c;
}

template<>
inline __device__ float4 mul(float4 a, float4 b) {
  float4 c;
  c.x = a.x * b.x;
  c.y = a.y * b.y;
  c.z = a.z * b.z;
  c.w = a.w * b.w;
  return c;
}

template<>
inline __device__ float4 mul(float a, float4 b) {
  float4 c;
  c.x = a * b.x;
  c.y = a * b.y;
  c.z = a * b.z;
  c.w = a * b.w;
  return c;
}

// Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) {
  return a * b + c;
}

inline __device__ float2 fma(float2 a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  return d;
}

inline __device__ float2 fma(float a, float2 b, float2 c) {
  float2 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ float4 fma(float4 a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a.x, b.x, c.x);
  d.y = fma(a.y, b.y, c.y);
  d.z = fma(a.z, b.z, c.z);
  d.w = fma(a.w, b.w, c.w);
  return d;
}

inline __device__ float4 fma(float a, float4 b, float4 c) {
  float4 d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
  Float4_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  return d;
}

inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
  Float8_ d;
  d.x = fma(a, b.x, c.x);
  d.y = fma(a, b.y, c.y);
  d.z = fma(a, b.z, c.z);
  d.w = fma(a, b.w, c.w);
  return d;
}

// Vector sum.
template<>
inline __device__ float sum(float v) {
  return v;
}

template<>
inline __device__ float sum(float2 v) {
  return v.x + v.y;
}

template<>
inline __device__ float sum(float4 v) {
  return v.x + v.y + v.z + v.w;
}

template<>
inline __device__ float sum(Float4_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y;
}

template<>
inline __device__ float sum(Float8_ v) {
  return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}

// Vector dot product.
inline __device__ float dot(float a, float b) {
  return a * b;
}

inline __device__ float dot(float2 a, float2 b) {
  float2 c = mul<float2, float2, float2>(a, b);
  return c.x + c.y;
}

inline __device__ float dot(Float4_ a, Float4_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  return acc.x + acc.y;
}

inline __device__ float dot(Float8_ a, Float8_ b) {
  float2 acc = mul<float2, float2, float2>(a.x, b.x);
  acc = fma(a.y, b.y, acc);
  acc = fma(a.z, b.z, acc);
  acc = fma(a.w, b.w, acc);
  return acc.x + acc.y;
}

// From float to float.
inline __device__ void from_float(float& dst, float src) {
  dst = src;
}

inline __device__ void from_float(float2& dst, float2 src) {
  dst = src;
}

inline __device__ void from_float(float4& dst, float4 src) {
  dst = src;
}

// From float to float.
inline __device__ float to_float(float u) {
  return u;
}

inline __device__ float2 to_float(float2 u) {
  return u;
}

inline __device__ float4 to_float(float4 u) {
  return u;
}

inline __device__ Float4_ to_float(Float4_ u) {
  return u;
}

inline __device__ Float8_ to_float(Float8_ u) {
  return u;
}

// Zero-out a variable.
inline __device__ void zero(float& dst) {
  dst = 0.f;
}

} // namespace vllm
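The FP32 file supplies the accumulator side: the custom `Float4_`/`Float8_` structs plus the `dot()` overloads that multiply lane-wise and then chain `fma()` before the final lane sum. A small illustrative sketch (not part of the diff; the function name and `scale` parameter are assumed):

```cuda
// Sketch only: reducing two eight-float accumulator vectors to a scaled score.
#include "dtype_float32.cuh"   // assumed include path

__device__ float score_sketch(const vllm::Float8_& q, const vllm::Float8_& k,
                              float scale) {
  // dot() performs mul<float2, float2, float2> on the first pair, chains
  // fma() over the remaining pairs, then sums the two lanes of the result.
  return scale * vllm::dot(q, k);
}
```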
(deleted file, 896 lines)
@@ -1,896 +0,0 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "attention_utils.h"
#include "cuda_primitives.h"
#include "reduction_utils.h"

#include <algorithm>

#define WARP_SIZE 32
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

namespace cacheflow {

// Grid: (num_heads, num_seqs).
template<
  typename scalar_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS>
__global__ void single_query_cached_kv_attention_kernel(
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_heads, head_size, block_size]
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_blocks_per_seq,
  const int q_stride) {
  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int thread_idx = threadIdx.x;
  const int warp_idx = thread_idx / WARP_SIZE;
  const int lane = thread_idx % WARP_SIZE;

  const int head_idx = blockIdx.x;
  const int num_heads = gridDim.x;
  const int seq_idx = blockIdx.y;

  // A vector type to store a part of a key or a query.
  // The vector size is configured in such a way that the threads in a thread group
  // fetch or compute 16 bytes at a time.
  // For example, if the size of a thread group is 4 and the data type is half,
  // then the vector size is 16 / (4 * sizeof(half)) == 2.
  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;

  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;

  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

  // Load the query to registers.
  // Each thread in a thread group has a different part of the query.
  // For example, if the thread group size is 4, then the first thread in the group
  // has the 0th, 4th, 8th, ... vectors of the query, the second thread has the
  // 1st, 5th, 9th, ... vectors, and so on.
  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  Q_vec q_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
  for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }

  // Memory planning.
  extern __shared__ char shared_mem[];
  // NOTE(woosuk): We use FP32 logits and accumulation.
  float *logits = reinterpret_cast<float*>(shared_mem);
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // x == THREAD_GROUP_SIZE * VEC_SIZE
  // Each thread group fetches x elements from the key at a time.
  constexpr int x = 16 / sizeof(scalar_t);
  float qk_max = -FLT_MAX;

  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
  const int context_len = context_lens[seq_idx];
  const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;

  // Iterate over the key blocks.
  // Each warp fetches a block of keys for each iteration.
  // Each thread group in a warp fetches a key from the block, and computes
  // dot product with the query.
  for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
    const int physical_block_number = block_table[block_idx];

    // Load a key to registers.
    // Each thread in a thread group has a different part of the key.
    // For example, if the thread group size is 4, then the first thread in the group
    // has the 0th, 4th, 8th, ... vectors of the key, the second thread has the
    // 1st, 5th, 9th, ... vectors, and so on.
    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      K_vec k_vecs[NUM_VECS_PER_THREAD];

#pragma unroll
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
        const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
                                        + head_idx * HEAD_SIZE * BLOCK_SIZE
                                        + physical_block_offset * x;
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
        const int offset1 = (vec_idx * VEC_SIZE) / x;
        const int offset2 = (vec_idx * VEC_SIZE) % x;
        k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
      }

      // Compute dot product.
      // This includes a reduction across the threads in the same thread group.
      const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
      const bool mask = token_idx >= context_len;

      if (thread_group_offset == 0) {
        // Store the partial reductions to shared memory.
        // NOTE(woosuk): It is required to zero out the masked logits.
        logits[token_idx] = mask ? 0.f : qk;
        // Update the max value.
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
      }
    }
  }

  // Perform reduction across the threads in the same warp to get the
  // max qk value for each "warp" (not across the thread block yet).
  // The 0-th thread of each thread group already has its max qk value.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;
  }
  __syncthreads();

  // TODO(woosuk): Refactor this part.
  // Get the max qk value for the sequence.
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);

  // Get the sum of the exp values.
  float exp_sum = 0.f;
  for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
  }
  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

  // Compute softmax.
  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
  for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
    logits[i] *= inv_sum;
  }
  __syncthreads();

  // Each thread will fetch 16 bytes from the value cache at a time.
  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using L_vec = typename FloatVec<V_vec>::Type;

  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
  constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;

  float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    accs[i] = 0.f;
  }

  for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
    const int physical_block_number = block_table[block_idx];
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    L_vec logits_vec = *reinterpret_cast<L_vec*>(logits + token_idx);

    const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
                                    + head_idx * HEAD_SIZE * BLOCK_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        accs[i] += dot(logits_vec, cast_to_float(v_vec));
      }
    }
  }

  // Perform reduction within each warp.
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
#pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
    }
    accs[i] = acc;
  }

  // NOTE(woosuk): A barrier is required because the shared memory space for logits
  // is reused for the output.
  __syncthreads();

  // Perform reduction across warps.
  float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
  for (int i = NUM_WARPS; i > 1; i /= 2) {
    int mid = i / 2;
    // Upper warps write to shared memory.
    if (warp_idx >= mid && warp_idx < i) {
      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          dst[row_idx] = accs[i];
        }
      }
    }
    __syncthreads();

    // Lower warps update the output.
    if (warp_idx < mid) {
      const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          accs[i] += src[row_idx];
        }
      }
    }
    __syncthreads();
  }

  // Write the final output.
  if (warp_idx == 0) {
    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        convert_from_float(*(out_ptr + row_idx), accs[i]);
      }
    }
  }
}

} // namespace cacheflow

#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS)                        \
  cacheflow::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>  \
  <<<grid, block, shared_mem_size, stream>>>(                                                 \
    out_ptr,                                                                                  \
    query_ptr,                                                                                \
    key_cache_ptr,                                                                            \
    value_cache_ptr,                                                                          \
    scale,                                                                                    \
    block_tables_ptr,                                                                         \
    context_lens_ptr,                                                                         \
    max_num_blocks_per_seq,                                                                   \
    query_stride);

// TODO(woosuk): Tune NUM_THREADS.
template<
  typename T,
  int BLOCK_SIZE,
  int NUM_THREADS = 128>
void single_query_cached_kv_attention_launcher(
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int max_context_len) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int query_stride = query.stride(0);

  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_context_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  int shared_mem_size = std::max(logits_size, outputs_size);

  dim3 grid(num_heads, num_seqs);
  dim3 block(NUM_THREADS);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    case 32:
      LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
      break;
    case 64:
      LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
      break;
    case 80:
      LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
      break;
    case 96:
      LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
      break;
    case 128:
      LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
      break;
    case 160:
      LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
      break;
    case 192:
      LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
      break;
    case 256:
      LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
      break;
    default:
      assert(false);
      break;
  }
}

#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE)                   \
  single_query_cached_kv_attention_launcher<T, BLOCK_SIZE>(   \
    out,                                                      \
    query,                                                    \
    key_cache,                                                \
    value_cache,                                              \
    scale,                                                    \
    block_tables,                                             \
    context_lens,                                             \
    max_context_len);

void single_query_cached_kv_attention(
  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
  float scale,
  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens,    // [num_seqs]
  int block_size,
  int max_context_len) {
  // TODO(woosuk): Support BF16.
  if (query.element_size() == 2) {
    // Half.
    if (block_size == 1) {
      CALL_KERNEL_LAUNCHER(uint16_t, 1);
    } else if (block_size == 2) {
      CALL_KERNEL_LAUNCHER(uint16_t, 2);
    } else if (block_size == 4) {
      CALL_KERNEL_LAUNCHER(uint16_t, 4);
    } else if (block_size == 8) {
      CALL_KERNEL_LAUNCHER(uint16_t, 8);
    } else if (block_size == 16) {
      CALL_KERNEL_LAUNCHER(uint16_t, 16);
    } else if (block_size == 32) {
      CALL_KERNEL_LAUNCHER(uint16_t, 32);
    } else if (block_size == 64) {
      CALL_KERNEL_LAUNCHER(uint16_t, 64);
    } else if (block_size == 128) {
      CALL_KERNEL_LAUNCHER(uint16_t, 128);
    } else if (block_size == 256) {
      CALL_KERNEL_LAUNCHER(uint16_t, 256);
    } else {
      assert(false);
    }
  } else {
    // Float.
    assert(false);
  }
}
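As an aside (not part of the deleted file), the launcher's dynamic shared-memory sizing can be sanity-checked on the host with the same formulas it uses; the concrete inputs below (block size 16, head size 128, context length 1000, 128 threads) are assumed purely for illustration:

```cpp
// Worked sketch of the shared-memory arithmetic in the launcher above.
#include <cstdio>
#include <algorithm>

int main() {
  const int WARP_SIZE = 32, NUM_THREADS = 128;
  const int NUM_WARPS = NUM_THREADS / WARP_SIZE;                 // 4
  const int BLOCK_SIZE = 16, head_size = 128;                    // assumed inputs
  const int max_context_len = 1000;                              // assumed input
  // Logits buffer: one float per context token, padded to a block boundary.
  int padded = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;  // 1008
  int logits_size = padded * (int)sizeof(float);                                // 4032 B
  // Output-reduction buffer: half the warps each stage a head_size row.
  int outputs_size = (NUM_WARPS / 2) * head_size * (int)sizeof(float);          // 1024 B
  int shared_mem_size = std::max(logits_size, outputs_size);                    // 4032 B
  std::printf("dynamic shared memory: %d bytes\n", shared_mem_size);
  return 0;
}
```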
// namespace cacheflow {

// // Grid: (num_heads, num_query_tokens).
// template<
//   typename scalar_t,
//   int HEAD_SIZE,
//   int BLOCK_SIZE,
//   int NUM_THREADS>
// __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
//   scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
//   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
//   const int seq_start_idx,
//   const int seq_len,
//   const scalar_t* __restrict__ k_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
//   const scalar_t* __restrict__ v_cache,   // [num_blocks, num_heads, head_size, block_size]
//   const float scale,
//   const int* __restrict__ block_table,    // [num_seqs, max_num_blocks_per_seq]
//   const int context_len,
//   const int max_num_blocks_per_seq,
//   const int q_stride) {
//   constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
//   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
//   const int thread_idx = threadIdx.x;
//   const int warp_idx = thread_idx / WARP_SIZE;
//   const int lane = thread_idx % WARP_SIZE;

//   const int head_idx = blockIdx.x;
//   const int num_heads = gridDim.x;
//   const int seq_idx = blockIdx.y;

//   // A vector type to store a part of a key or a query.
//   // The vector size is configured in such a way that the threads in a thread group
//   // fetch or compute 16 bytes at a time.
//   // For example, if the size of a thread group is 4 and the data type is half,
//   // then the vector size is 16 / (4 * sizeof(half)) == 2.
//   constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t));
//   using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
//   using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;

//   constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
//   constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;

//   const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
//   const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

//   // Load the query to registers.
//   // Each thread in a thread group has a different part of the query.
//   // For example, if the thread group size is 4, then the first thread in the group
//   // has the 0th, 4th, 8th, ... vectors of the query, the second thread has the
//   // 1st, 5th, 9th, ... vectors, and so on.
//   // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
//   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
//   Q_vec q_vecs[NUM_VECS_PER_THREAD];
// #pragma unroll
//   for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
//     const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
//     q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
//   }

//   // Memory planning.
//   extern __shared__ char shared_mem[];
//   // NOTE(woosuk): We use FP32 logits and accumulation.
//   float *logits = reinterpret_cast<float*>(shared_mem);
//   // Workspace for reduction.
//   __shared__ float red_smem[2 * NUM_WARPS];

//   // x == THREAD_GROUP_SIZE * VEC_SIZE
//   // Each thread group fetches x elements from the key at a time.
//   constexpr int x = 16 / sizeof(scalar_t);
//   float qk_max = -FLT_MAX;

//   const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
//   const int mask_boundary = context_len - seq_len + 1 + (seq_idx - seq_start_idx);

//   // Iterate over the key blocks.
//   // Each warp fetches a block of keys for each iteration.
//   // Each thread group in a warp fetches a key from the block, and computes
//   // dot product with the query.
//   for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
//     const int physical_block_number = block_table[block_idx];
//     const int physical_block_offset = thread_group_idx % BLOCK_SIZE;
//     const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;

//     // Load a key to registers.
//     // Each thread in a thread group has a different part of the key.
//     // For example, if the thread group size is 4, then the first thread in the group
//     // has the 0th, 4th, 8th, ... vectors of the key, the second thread has the
//     // 1st, 5th, 9th, ... vectors, and so on.
//     K_vec k_vecs[NUM_VECS_PER_THREAD];
// #pragma unroll
//     for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
//       const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
//                                       + head_idx * HEAD_SIZE * BLOCK_SIZE
//                                       + physical_block_offset * x;
//       const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
//       const int offset1 = (vec_idx * VEC_SIZE) / x;
//       const int offset2 = (vec_idx * VEC_SIZE) % x;
//       k_vecs[i] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
//     }

//     // Compute dot product.
//     // This includes a reduction across the threads in the same thread group.
//     const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
//     const bool mask = token_idx >= mask_boundary;

//     if (thread_group_offset == 0) {
//       // Store the partial reductions to shared memory.
//       // NOTE(woosuk): It is required to zero out the masked logits.
//       logits[token_idx] = mask ? 0.f : qk;
//       // Update the max value.
//       qk_max = mask ? qk_max : fmaxf(qk_max, qk);
//     }
//   }

//   // Perform reduction across the threads in the same warp to get the
//   // max qk value for each "warp" (not across the thread block yet).
//   // The 0-th thread of each thread group already has its max qk value.
// #pragma unroll
//   for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
//     qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
//   }
//   if (lane == 0) {
//     red_smem[warp_idx] = qk_max;
//   }
//   __syncthreads();

//   // TODO(woosuk): Refactor this part.
//   // Get the max qk value for the sequence.
//   qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
// #pragma unroll
//   for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
//     qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
//   }
//   // Broadcast the max qk value to all threads.
//   qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);

//   // Get the sum of the exp values.
//   float exp_sum = 0.f;
//   for (int i = thread_idx; i < mask_boundary; i += NUM_THREADS) {
//     float val = __expf(logits[i] - qk_max);
//     logits[i] = val;
//     exp_sum += val;
//   }
//   exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

//   // Compute softmax.
//   const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
//   for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
//     logits[i] *= inv_sum;
//   }
//   __syncthreads();

//   // Each thread will fetch 16 bytes from the value cache at a time.
//   constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t);
//   using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
//   using L_vec = typename FloatVec<V_vec>::Type;

//   constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
//   constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
//   constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;

//   float accs[NUM_ROWS_PER_THREAD];
// #pragma unroll
//   for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
//     accs[i] = 0.f;
//   }

//   for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
//     const int physical_block_number = block_table[block_idx];
//     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
//     const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
//     L_vec logits_vec = *reinterpret_cast<L_vec*>(logits + token_idx);
|
|
||||||
|
|
||||||
// const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
|
|
||||||
// + head_idx * HEAD_SIZE * BLOCK_SIZE;
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
|
||||||
// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
|
||||||
// if (row_idx < HEAD_SIZE) {
|
|
||||||
// const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
|
||||||
// V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
|
||||||
// accs[i] += dot(logits_vec, cast_to_float(v_vec));
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Perform reduction within each warp.
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
|
||||||
// float acc = accs[i];
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
|
|
||||||
// acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
|
|
||||||
// }
|
|
||||||
// accs[i] = acc;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // NOTE(woosuk): A barrier is required because the shared memory space for logits
|
|
||||||
// // is reused for the output.
|
|
||||||
// __syncthreads();
|
|
||||||
|
|
||||||
// // Perform reduction across warps.
|
|
||||||
// float* out_smem = reinterpret_cast<float*>(shared_mem);
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = NUM_WARPS; i > 1; i /= 2) {
|
|
||||||
// int mid = i / 2;
|
|
||||||
// // Upper warps write to shared memory.
|
|
||||||
// if (warp_idx >= mid && warp_idx < i) {
|
|
||||||
// float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
|
||||||
// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
|
||||||
// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
|
|
||||||
// dst[row_idx] = accs[i];
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// __syncthreads();
|
|
||||||
|
|
||||||
// // Lower warps update the output.
|
|
||||||
// if (warp_idx < mid) {
|
|
||||||
// const float* src = &out_smem[warp_idx * HEAD_SIZE];
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
|
||||||
// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
|
||||||
// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
|
|
||||||
// accs[i] += src[row_idx];
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// __syncthreads();
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Write the final output.
|
|
||||||
// if (warp_idx == 0) {
|
|
||||||
// scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
|
||||||
// #pragma unroll
|
|
||||||
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
|
||||||
// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
|
||||||
// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
|
|
||||||
// convert_from_float(*(out_ptr + row_idx), accs[i]);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
|
|
||||||
// // Grid: (num_heads, num_query_tokens).
|
|
||||||
// template<
|
|
||||||
// typename scalar_t,
|
|
||||||
// int HEAD_SIZE,
|
|
||||||
// int BLOCK_SIZE,
|
|
||||||
// int NUM_THREADS>
|
|
||||||
// __global__ void multi_query_cached_kv_attention_kernel(
|
|
||||||
// const int* cu_query_lens, // [num_prompts+1]
|
|
||||||
// const int* seq_prompt_mapping, // [num_seqs] mapping from seq_idx to prompt_idx
|
|
||||||
// scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
|
||||||
// const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
|
||||||
// const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
|
||||||
// const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
|
|
||||||
// const float scale,
|
|
||||||
// const int* __restrict__ block_tables, // [num_prompts, max_num_blocks_per_seq]
|
|
||||||
// const int* __restrict__ context_lens, // [num_prompts]
|
|
||||||
// const int max_num_blocks_per_seq,
|
|
||||||
// const int q_stride) {
|
|
||||||
// const int seq_idx = blockIdx.y;
|
|
||||||
// const int prompt_idx = seq_prompt_mapping[seq_idx];
|
|
||||||
// const int seq_start_idx = cu_query_lens[prompt_idx];
|
|
||||||
// const int seq_len = cu_query_lens[prompt_idx + 1] - seq_start_idx;
|
|
||||||
// const int* block_table = block_tables + prompt_idx * max_num_blocks_per_seq;
|
|
||||||
// const int context_len = context_lens[prompt_idx];
|
|
||||||
// multi_query_cached_kv_attention_kernel_unoptimized_<
|
|
||||||
// scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
|
|
||||||
// out,
|
|
||||||
// q,
|
|
||||||
// seq_start_idx,
|
|
||||||
// seq_len,
|
|
||||||
// k_cache,
|
|
||||||
// v_cache,
|
|
||||||
// scale,
|
|
||||||
// block_table,
|
|
||||||
// context_len,
|
|
||||||
// max_num_blocks_per_seq,
|
|
||||||
// q_stride);
|
|
||||||
// }
|
|
||||||
|
|
||||||
// } // namespace cacheflow
|
|
||||||
|
|
||||||
// #define LAUNCH_MULTI_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
|
|
||||||
// cacheflow::multi_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
|
|
||||||
// <<<grid, block, shared_mem_size, stream>>>( \
|
|
||||||
// cu_query_lens_ptr, \
|
|
||||||
// seq_prompt_mapping_ptr, \
|
|
||||||
// out_ptr, \
|
|
||||||
// query_ptr, \
|
|
||||||
// key_cache_ptr, \
|
|
||||||
// value_cache_ptr, \
|
|
||||||
// scale, \
|
|
||||||
// block_tables_ptr, \
|
|
||||||
// context_lens_ptr, \
|
|
||||||
// max_num_blocks_per_seq, \
|
|
||||||
// query_stride);
|
|
||||||
|
|
||||||
|
|
||||||
// // TODO(woosuk): Tune NUM_THREADS.
|
|
||||||
// template<
|
|
||||||
// typename T,
|
|
||||||
// int BLOCK_SIZE,
|
|
||||||
// int NUM_THREADS = 128>
|
|
||||||
// void multi_query_cached_kv_attention_launcher(
|
|
||||||
// torch::Tensor& cu_query_lens,
|
|
||||||
// torch::Tensor& seq_prompt_mapping,
|
|
||||||
// torch::Tensor& out,
|
|
||||||
// torch::Tensor& query,
|
|
||||||
// torch::Tensor& key_cache,
|
|
||||||
// torch::Tensor& value_cache,
|
|
||||||
// float scale,
|
|
||||||
// torch::Tensor& block_tables,
|
|
||||||
// torch::Tensor& context_lens,
|
|
||||||
// int max_context_len) {
|
|
||||||
// int num_seqs = query.size(0);
|
|
||||||
// int num_heads = query.size(1);
|
|
||||||
// int head_size = query.size(2);
|
|
||||||
// int max_num_blocks_per_seq = block_tables.size(1);
|
|
||||||
// int query_stride = query.stride(0);
|
|
||||||
|
|
||||||
// int* cu_query_lens_ptr = cu_query_lens.data_ptr<int>();
|
|
||||||
// int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr<int>();
|
|
||||||
// T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
|
||||||
// T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
|
||||||
// T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
|
||||||
// T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
|
||||||
// int* block_tables_ptr = block_tables.data_ptr<int>();
|
|
||||||
// int* context_lens_ptr = context_lens.data_ptr<int>();
|
|
||||||
|
|
||||||
// constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
|
||||||
// int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
|
|
||||||
// int logits_size = padded_max_context_len * sizeof(float);
|
|
||||||
// int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
|
||||||
// int shared_mem_size = std::max(logits_size, outputs_size);
|
|
||||||
|
|
||||||
// dim3 grid(num_heads, num_seqs);
|
|
||||||
// dim3 block(NUM_THREADS);
|
|
||||||
// const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
|
||||||
// switch (head_size) {
|
|
||||||
// case 32:
|
|
||||||
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
|
|
||||||
// break;
|
|
||||||
// case 64:
|
|
||||||
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
|
|
||||||
// break;
|
|
||||||
// case 80:
|
|
||||||
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
|
|
||||||
// break;
|
|
||||||
// case 96:
|
|
||||||
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
|
|
||||||
// break;
|
|
||||||
// case 128:
|
|
||||||
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
|
|
||||||
// break;
|
|
||||||
// case 160:
|
|
||||||
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
|
|
||||||
// break;
|
|
||||||
// case 192:
|
|
||||||
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
|
|
||||||
// break;
|
|
||||||
// case 256:
|
|
||||||
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
|
|
||||||
// break;
|
|
||||||
// default:
|
|
||||||
// assert(false);
|
|
||||||
// break;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// void multi_query_cached_kv_attention(
|
|
||||||
// torch::Tensor& cu_query_lens,
|
|
||||||
// torch::Tensor& out,
|
|
||||||
// torch::Tensor& query,
|
|
||||||
// torch::Tensor& key_cache,
|
|
||||||
// torch::Tensor& value_cache,
|
|
||||||
// float scale,
|
|
||||||
// torch::Tensor& block_tables,
|
|
||||||
// torch::Tensor& context_lens,
|
|
||||||
// int block_size,
|
|
||||||
// int max_context_len) {
|
|
||||||
|
|
||||||
// torch::Tensor query_lens = cu_query_lens.to(torch::kCPU);
|
|
||||||
|
|
||||||
// int num_queries = query_lens.size(0) - 1;
|
|
||||||
// const int* query_lens_ptr = query_lens.data_ptr<int>();
|
|
||||||
// int num_seqs = query.size(0);
|
|
||||||
|
|
||||||
// torch::Tensor cpu_tensor = torch::empty({num_seqs}, torch::dtype(torch::kInt32));
|
|
||||||
// auto accessor = cpu_tensor.accessor<int32_t, 1>();
|
|
||||||
// for (int i = 0, query_cursor = 0; i < num_seqs; ++i) {
|
|
||||||
// if (i >= query_lens_ptr[query_cursor + 1]) {
|
|
||||||
// ++query_cursor;
|
|
||||||
// }
|
|
||||||
// accessor[i] = query_cursor;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // TODO(suquark): This can be slow, as it to(torch::kCPU) and to(torch::kCUDA)
|
|
||||||
// // implicitly synchronizes the CPU and GPU. And we can avoid this issue by giving
|
|
||||||
// // the mapping as an input parameter. Let's do this optimization in a later PR.
|
|
||||||
// torch::Tensor seq_prompt_mapping = cpu_tensor.to(torch::kCUDA);
|
|
||||||
|
|
||||||
// // TODO(woosuk): Support BF16.
|
|
||||||
// if (query.element_size() == 2) {
|
|
||||||
// // Half.
|
|
||||||
// if (block_size == 8) {
|
|
||||||
// multi_query_cached_kv_attention_launcher<uint16_t, 8>(
|
|
||||||
// cu_query_lens,
|
|
||||||
// seq_prompt_mapping,
|
|
||||||
// out,
|
|
||||||
// query,
|
|
||||||
// key_cache,
|
|
||||||
// value_cache,
|
|
||||||
// scale,
|
|
||||||
// block_tables,
|
|
||||||
// context_lens,
|
|
||||||
// max_context_len);
|
|
||||||
// } else if (block_size == 16) {
|
|
||||||
// multi_query_cached_kv_attention_launcher<uint16_t, 16>(
|
|
||||||
// cu_query_lens,
|
|
||||||
// seq_prompt_mapping,
|
|
||||||
// out,
|
|
||||||
// query,
|
|
||||||
// key_cache,
|
|
||||||
// value_cache,
|
|
||||||
// scale,
|
|
||||||
// block_tables,
|
|
||||||
// context_lens,
|
|
||||||
// max_context_len);
|
|
||||||
// } else if (block_size == 32) {
|
|
||||||
// multi_query_cached_kv_attention_launcher<uint16_t, 32>(
|
|
||||||
// cu_query_lens,
|
|
||||||
// seq_prompt_mapping,
|
|
||||||
// out,
|
|
||||||
// query,
|
|
||||||
// key_cache,
|
|
||||||
// value_cache,
|
|
||||||
// scale,
|
|
||||||
// block_tables,
|
|
||||||
// context_lens,
|
|
||||||
// max_context_len);
|
|
||||||
// } else {
|
|
||||||
// assert(false);
|
|
||||||
// }
|
|
||||||
// } else if (query.element_size() == 4) {
|
|
||||||
// // Float.
|
|
||||||
// if (block_size == 8) {
|
|
||||||
// multi_query_cached_kv_attention_launcher<float, 8>(
|
|
||||||
// cu_query_lens,
|
|
||||||
// seq_prompt_mapping,
|
|
||||||
// out,
|
|
||||||
// query,
|
|
||||||
// key_cache,
|
|
||||||
// value_cache,
|
|
||||||
// scale,
|
|
||||||
// block_tables,
|
|
||||||
// context_lens,
|
|
||||||
// max_context_len);
|
|
||||||
// } else if (block_size == 16) {
|
|
||||||
// multi_query_cached_kv_attention_launcher<float, 16>(
|
|
||||||
// cu_query_lens,
|
|
||||||
// seq_prompt_mapping,
|
|
||||||
// out,
|
|
||||||
// query,
|
|
||||||
// key_cache,
|
|
||||||
// value_cache,
|
|
||||||
// scale,
|
|
||||||
// block_tables,
|
|
||||||
// context_lens,
|
|
||||||
// max_context_len);
|
|
||||||
// } else if (block_size == 32) {
|
|
||||||
// multi_query_cached_kv_attention_launcher<float, 32>(
|
|
||||||
// cu_query_lens,
|
|
||||||
// seq_prompt_mapping,
|
|
||||||
// out,
|
|
||||||
// query,
|
|
||||||
// key_cache,
|
|
||||||
// value_cache,
|
|
||||||
// scale,
|
|
||||||
// block_tables,
|
|
||||||
// context_lens,
|
|
||||||
// max_context_len);
|
|
||||||
// } else {
|
|
||||||
// assert(false);
|
|
||||||
// }
|
|
||||||
// } else {
|
|
||||||
// assert(false);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
#undef WARP_SIZE
|
|
||||||
#undef MAX
|
|
||||||
#undef MIN
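The deleted kernel above builds FP32 logits in shared memory, takes a max-subtracted exponential, and normalizes by the epsilon-guarded sum. A minimal host-side C++ reference of that softmax step, for illustration only (the function name is mine, not part of the kernel):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Reference for the logits -> probabilities step of the kernel above:
// subtract the running max for numerical stability, exponentiate, then
// scale by 1 / (sum + 1e-6f), matching the __expf / __fdividef path.
std::vector<float> softmax_reference(const std::vector<float>& logits) {
  float qk_max = -std::numeric_limits<float>::infinity();
  for (float v : logits) qk_max = std::max(qk_max, v);
  std::vector<float> probs(logits.size());
  float exp_sum = 0.f;
  for (std::size_t i = 0; i < logits.size(); ++i) {
    probs[i] = std::exp(logits[i] - qk_max);
    exp_sum += probs[i];
  }
  const float inv_sum = 1.f / (exp_sum + 1e-6f);
  for (float& p : probs) p *= inv_sum;
  return probs;
}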
|
|
||||||
@ -1,165 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "cuda_primitives.h"
|
|
||||||
|
|
||||||
#include <float.h>
|
|
||||||
#include <type_traits>
|
|
||||||
|
|
||||||
#define MMHA_USE_FP32_ACUM_FOR_FMA
|
|
||||||
#define MMHA_USE_FP32_ACUM_FOR_OUT
|
|
||||||
|
|
||||||
namespace cacheflow {
|
|
||||||
|
|
||||||
// A vector type to store Q, K, V elements.
|
|
||||||
template<typename T, int VEC_SIZE>
|
|
||||||
struct Vec {};
|
|
||||||
template<>
|
|
||||||
struct Vec<float, 1> {
|
|
||||||
using Type = float;
|
|
||||||
};
|
|
||||||
template<>
|
|
||||||
struct Vec<float, 2> {
|
|
||||||
using Type = float2;
|
|
||||||
};
|
|
||||||
template<>
|
|
||||||
struct Vec<float, 4> {
|
|
||||||
using Type = float4;
|
|
||||||
};
|
|
||||||
template<>
|
|
||||||
struct Vec<uint16_t, 1> {
|
|
||||||
using Type = uint16_t;
|
|
||||||
};
|
|
||||||
template<>
|
|
||||||
struct Vec<uint16_t, 2> {
|
|
||||||
using Type = uint32_t;
|
|
||||||
};
|
|
||||||
template<>
|
|
||||||
struct Vec<uint16_t, 4> {
|
|
||||||
using Type = uint2;
|
|
||||||
};
|
|
||||||
template<>
|
|
||||||
struct Vec<uint16_t, 8> {
|
|
||||||
using Type = uint4;
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
struct FloatVec {};
|
|
||||||
template<>
|
|
||||||
struct FloatVec<float> {
|
|
||||||
using Type = float;
|
|
||||||
};
|
|
||||||
template<>
|
|
||||||
struct FloatVec<float2> {
|
|
||||||
using Type = float2;
|
|
||||||
};
|
|
||||||
template<>
|
|
||||||
struct FloatVec<float4> {
|
|
||||||
using Type = float4;
|
|
||||||
};
|
|
||||||
template<>
|
|
||||||
struct FloatVec<uint16_t> {
|
|
||||||
using Type = float;
|
|
||||||
};
|
|
||||||
template<>
|
|
||||||
struct FloatVec<uint32_t> {
|
|
||||||
using Type = float2;
|
|
||||||
};
|
|
||||||
template<>
|
|
||||||
struct FloatVec<uint2> {
|
|
||||||
using Type = Float4_;
|
|
||||||
};
|
|
||||||
template<>
|
|
||||||
struct FloatVec<uint4> {
|
|
||||||
using Type = Float8_;
|
|
||||||
};
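These traits map a scalar type and vector width to the packed register type used for loads, and FloatVec maps that packed type to its FP32 accumulator. A short illustrative alias block, assuming fp16 data read four elements per thread:

// Illustrative only: resolve the packed load type and its FP32 accumulator.
using K_vec = typename Vec<uint16_t, 4>::Type;      // uint2: four packed fp16 values
using K_vec_acum = typename FloatVec<K_vec>::Type;  // Float4_: four fp32 accumulators
static_assert(sizeof(K_vec) == 4 * sizeof(uint16_t), "packed fp16 width mismatch");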
|
|
||||||
|
|
||||||
template<int THREADS_PER_KEY, typename K_vec, int N>
|
|
||||||
inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
|
|
||||||
{
|
|
||||||
using K_vec_acum = typename FloatVec<K_vec>::Type;
|
|
||||||
// Compute the parallel products for Q*K^T (treat vector lanes separately).
|
|
||||||
K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
|
|
||||||
#pragma unroll
|
|
||||||
for (int ii = 1; ii < N; ++ii) {
|
|
||||||
qk_vec = fma(q[ii], k[ii], qk_vec);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finalize the reduction across lanes.
|
|
||||||
float qk = sum(qk_vec);
|
|
||||||
#pragma unroll
|
|
||||||
for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
|
|
||||||
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
|
|
||||||
}
|
|
||||||
return qk;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template<typename T, int THREADS_PER_KEY>
|
|
||||||
struct Qk_dot {
|
|
||||||
template<typename K_vec, int N>
|
|
||||||
static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
|
|
||||||
{
|
|
||||||
return qk_dot_<THREADS_PER_KEY>(q, k);
|
|
||||||
}
|
|
||||||
};
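qk_dot_ finishes with an XOR-shuffle butterfly that sums the partial dot products across the THREADS_PER_KEY lanes of a thread group. A standalone CUDA sketch of that reduction pattern on plain floats (the helper name is mine, not a vLLM symbol):

// Sum a value across groups of WIDTH consecutive lanes (WIDTH a power of two
// <= 32); after the loop every lane in the group holds the group's total.
template<int WIDTH>
__device__ float group_sum(float v) {
#pragma unroll
  for (int mask = WIDTH / 2; mask >= 1; mask /= 2) {
    v += __shfl_xor_sync(uint32_t(-1), v, mask);
  }
  return v;
}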
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
|
|
||||||
{
|
|
||||||
float4 c;
|
|
||||||
float zero = 0.f;
|
|
||||||
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
|
|
||||||
" {%0, %1, %2, %3}, \n"
|
|
||||||
" {%4, %5}, \n"
|
|
||||||
" {%6}, \n"
|
|
||||||
" {%7, %7, %7, %7}; \n"
|
|
||||||
|
|
||||||
: "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
|
|
||||||
: "r"(a.x) "r"(a.y), "r"(b), "f"(zero));
|
|
||||||
return c;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template<int N>
|
|
||||||
inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
|
|
||||||
{
|
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
|
||||||
using K_vec_acum = typename FloatVec<uint32_t>::Type;
|
|
||||||
K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
|
|
||||||
#pragma unroll
|
|
||||||
for (int ii = 1; ii < N; ++ii) {
|
|
||||||
qk_vec = fma(q[ii], k[ii], qk_vec);
|
|
||||||
}
|
|
||||||
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
|
|
||||||
uint32_t qk_vec_ = float2_to_half2(qk_vec);
|
|
||||||
return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
|
|
||||||
#else
|
|
||||||
return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
|
|
||||||
#endif
|
|
||||||
#else
|
|
||||||
return 0.f;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template<>
|
|
||||||
struct Qk_dot<uint16_t, 4> {
|
|
||||||
template<int N>
|
|
||||||
static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
|
|
||||||
{
|
|
||||||
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
|
|
||||||
return qk_hmma_dot_(q, k);
|
|
||||||
#else
|
|
||||||
return qk_dot_<4>(q, k);
|
|
||||||
#endif // defined MMHA_USE_HMMA_FOR_REDUCTION
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace cacheflow
|
|
||||||
|
|
||||||
#undef MMHA_USE_FP32_ACUM_FOR_FMA
|
|
||||||
#undef MMHA_USE_FP32_ACUM_FOR_OUT
|
|
||||||
@ -1,6 +1,8 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <map>
|
#include <map>
|
||||||
@ -14,14 +16,16 @@ void swap_blocks(
|
|||||||
torch::Device dst_device = dst.device();
|
torch::Device dst_device = dst.device();
|
||||||
cudaMemcpyKind memcpy_type;
|
cudaMemcpyKind memcpy_type;
|
||||||
if (src_device.is_cuda() && dst_device.is_cuda()) {
|
if (src_device.is_cuda() && dst_device.is_cuda()) {
|
||||||
assert(src_device.index() == dst_device.index());
|
TORCH_CHECK(
|
||||||
|
src_device.index() == dst_device.index(),
|
||||||
|
"src and dst must be on the same GPU");
|
||||||
memcpy_type = cudaMemcpyDeviceToDevice;
|
memcpy_type = cudaMemcpyDeviceToDevice;
|
||||||
} else if (src_device.is_cuda() && dst_device.is_cpu()) {
|
} else if (src_device.is_cuda() && dst_device.is_cpu()) {
|
||||||
memcpy_type = cudaMemcpyDeviceToHost;
|
memcpy_type = cudaMemcpyDeviceToHost;
|
||||||
} else if (src_device.is_cpu() && dst_device.is_cuda()) {
|
} else if (src_device.is_cpu() && dst_device.is_cuda()) {
|
||||||
memcpy_type = cudaMemcpyHostToDevice;
|
memcpy_type = cudaMemcpyHostToDevice;
|
||||||
} else {
|
} else {
|
||||||
assert(false);
|
TORCH_CHECK(false, "Invalid device combination");
|
||||||
}
|
}
|
||||||
|
|
||||||
void *src_ptr = src.data_ptr();
|
void *src_ptr = src.data_ptr();
|
||||||
@ -29,6 +33,7 @@ void swap_blocks(
|
|||||||
|
|
||||||
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
|
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
// NOTE(woosuk): This can be slow if the number of blocks is large.
|
||||||
for (const auto& pair : block_mapping) {
|
for (const auto& pair : block_mapping) {
|
||||||
int64_t src_block_number = pair.first;
|
int64_t src_block_number = pair.first;
|
||||||
int64_t dst_block_number = pair.second;
|
int64_t dst_block_number = pair.second;
|
||||||
@ -43,7 +48,7 @@ void swap_blocks(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace cacheflow {
|
namespace vllm {
|
||||||
|
|
||||||
// Grid: (num_layers, num_pairs)
|
// Grid: (num_layers, num_pairs)
|
||||||
template<typename scalar_t>
|
template<typename scalar_t>
|
||||||
@ -74,7 +79,7 @@ __global__ void copy_blocks_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cacheflow
|
} // namespace vllm
|
||||||
|
|
||||||
void copy_blocks(
|
void copy_blocks(
|
||||||
std::vector<torch::Tensor>& key_caches,
|
std::vector<torch::Tensor>& key_caches,
|
||||||
@ -122,9 +127,9 @@ void copy_blocks(
|
|||||||
dim3 grid(num_layers, num_pairs);
|
dim3 grid(num_layers, num_pairs);
|
||||||
dim3 block(std::min(1024, numel_per_block));
|
dim3 block(std::min(1024, numel_per_block));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
||||||
cacheflow::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||||
value_cache_ptrs_tensor.data_ptr<int64_t>(),
|
value_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||||
block_mapping_tensor.data_ptr<int>(),
|
block_mapping_tensor.data_ptr<int>(),
|
||||||
@ -132,7 +137,7 @@ void copy_blocks(
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace cacheflow {
|
namespace vllm {
|
||||||
|
|
||||||
template<typename scalar_t>
|
template<typename scalar_t>
|
||||||
__global__ void reshape_and_cache_kernel(
|
__global__ void reshape_and_cache_kernel(
|
||||||
@ -176,6 +181,48 @@ __global__ void reshape_and_cache_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
void reshape_and_cache(
|
||||||
|
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||||
|
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||||
|
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
|
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
|
torch::Tensor& slot_mapping) // [num_tokens]
|
||||||
|
{
|
||||||
|
int num_tokens = key.size(0);
|
||||||
|
int num_heads = key.size(1);
|
||||||
|
int head_size = key.size(2);
|
||||||
|
int block_size = key_cache.size(3);
|
||||||
|
int x = key_cache.size(4);
|
||||||
|
|
||||||
|
int key_stride = key.stride(0);
|
||||||
|
int value_stride = value.stride(0);
|
||||||
|
|
||||||
|
dim3 grid(num_tokens);
|
||||||
|
dim3 block(std::min(num_heads * head_size, 512));
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
|
key.scalar_type(),
|
||||||
|
"reshape_and_cache_kernel",
|
||||||
|
[&] {
|
||||||
|
vllm::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
value.data_ptr<scalar_t>(),
|
||||||
|
key_cache.data_ptr<scalar_t>(),
|
||||||
|
value_cache.data_ptr<scalar_t>(),
|
||||||
|
slot_mapping.data_ptr<int>(),
|
||||||
|
key_stride,
|
||||||
|
value_stride,
|
||||||
|
num_heads,
|
||||||
|
head_size,
|
||||||
|
block_size,
|
||||||
|
x);
|
||||||
|
});
|
||||||
|
}
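The launcher above relies on the cache layout noted in the comments: the key cache is [num_blocks, num_heads, head_size/x, block_size, x]. A hedged sketch of the flat offset that layout implies for one (slot, head, dim) element; the helper name and signature are mine:

#include <cstdint>

// Flat offset of key_cache[block][head][dim / x][block_offset][dim % x]
// under the layout [num_blocks, num_heads, head_size/x, block_size, x].
inline int64_t key_cache_index(int64_t block_number, int block_offset,
                               int head_idx, int dim_idx,
                               int num_heads, int head_size,
                               int block_size, int x) {
  const int x_idx = dim_idx / x;
  const int x_off = dim_idx % x;
  return (((block_number * num_heads + head_idx) * (head_size / x) + x_idx)
              * block_size + block_offset) * x + x_off;
}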
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
// Grid: (num_blocks, block_size).
|
// Grid: (num_blocks, block_size).
|
||||||
template<typename scalar_t>
|
template<typename scalar_t>
|
||||||
__global__ void gather_cached_kv_kernel(
|
__global__ void gather_cached_kv_kernel(
|
||||||
@ -294,46 +341,7 @@ __global__ void gather_cached_kv_kernel_optimized(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cacheflow
|
} // namespace vllm
|
||||||
|
|
||||||
void reshape_and_cache(
|
|
||||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
|
||||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
|
||||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
|
||||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
|
||||||
torch::Tensor& slot_mapping) // [num_tokens]
|
|
||||||
{
|
|
||||||
int num_tokens = key.size(0);
|
|
||||||
int num_heads = key.size(1);
|
|
||||||
int head_size = key.size(2);
|
|
||||||
int block_size = key_cache.size(3);
|
|
||||||
int x = key_cache.size(4);
|
|
||||||
|
|
||||||
int key_stride = key.stride(0);
|
|
||||||
int value_stride = value.stride(0);
|
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
|
||||||
dim3 block(std::min(num_heads * head_size, 512));
|
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
|
||||||
key.scalar_type(),
|
|
||||||
"reshape_and_cache_kernel",
|
|
||||||
[&] {
|
|
||||||
cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
|
||||||
key.data_ptr<scalar_t>(),
|
|
||||||
value.data_ptr<scalar_t>(),
|
|
||||||
key_cache.data_ptr<scalar_t>(),
|
|
||||||
value_cache.data_ptr<scalar_t>(),
|
|
||||||
slot_mapping.data_ptr<int>(),
|
|
||||||
key_stride,
|
|
||||||
value_stride,
|
|
||||||
num_heads,
|
|
||||||
head_size,
|
|
||||||
block_size,
|
|
||||||
x);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
void gather_cached_kv(
|
void gather_cached_kv(
|
||||||
torch::Tensor& key, // [out] [num_tokens, num_heads, head_size]
|
torch::Tensor& key, // [out] [num_tokens, num_heads, head_size]
|
||||||
@ -354,11 +362,11 @@ void gather_cached_kv(
|
|||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(num_heads * head_size, 512));
|
dim3 block(std::min(num_heads * head_size, 512));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
key.scalar_type(),
|
key.scalar_type(),
|
||||||
"gather_cached_kv_kernel_optimized",
|
"gather_cached_kv_kernel_optimized",
|
||||||
[&] {
|
[&] {
|
||||||
cacheflow::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
|
vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
key.data_ptr<scalar_t>(),
|
key.data_ptr<scalar_t>(),
|
||||||
value.data_ptr<scalar_t>(),
|
value.data_ptr<scalar_t>(),
|
||||||
key_cache.data_ptr<scalar_t>(),
|
key_cache.data_ptr<scalar_t>(),
|
||||||
|
|||||||
File diff suppressed because it is too large
13
csrc/cuda_utils.cpp
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
int get_device_attribute(
|
||||||
|
int attribute,
|
||||||
|
int device_id);
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def(
|
||||||
|
"get_device_attribute",
|
||||||
|
&get_device_attribute,
|
||||||
|
"Gets the specified device attribute.");
|
||||||
|
}
|
||||||
|
|
||||||
14
csrc/cuda_utils_kernels.cu
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
int get_device_attribute(
|
||||||
|
int attribute,
|
||||||
|
int device_id)
|
||||||
|
{
|
||||||
|
int device, value;
|
||||||
|
if (device_id < 0) {
|
||||||
|
cudaGetDevice(&device);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
device = device_id;
|
||||||
|
}
|
||||||
|
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
|
||||||
|
return value;
|
||||||
|
}
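A hedged usage sketch for the new helper: any cudaDeviceAttr value can be passed (the enum converts to int), and a negative device_id queries the current device. The surrounding function name is illustrative:

#include <cuda_runtime.h>

// Example: query the opt-in shared memory limit of the current device.
int query_max_optin_smem() {
  return get_device_attribute(cudaDevAttrMaxSharedMemoryPerBlockOptin,
                              /*device_id=*/-1);
}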
|
||||||
14
csrc/dispatch_utils.h
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
/*
|
||||||
|
* Adapted from
|
||||||
|
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
|
||||||
|
*/
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||||
|
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||||
|
AT_DISPATCH_SWITCH( \
|
||||||
|
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
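VLLM_DISPATCH_FLOATING_TYPES expands to an AT_DISPATCH_SWITCH over Float, Half and BFloat16, binding scalar_t inside the lambda. A hedged usage sketch; my_kernel, "my_op" and the surrounding variables are placeholders, not vLLM symbols:

// Dispatch a templated kernel over the three supported floating dtypes.
// Assumes input/output tensors and grid/block/stream are already set up.
VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(), "my_op", [&] {
      my_kernel<scalar_t><<<grid, block, 0, stream>>>(
          output.data_ptr<scalar_t>(),
          input.data_ptr<scalar_t>());
    });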
|
||||||
@ -1,9 +1,10 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
#include "reduction_utils.h"
|
#include "dispatch_utils.h"
|
||||||
|
#include "reduction_utils.cuh"
|
||||||
|
|
||||||
namespace cacheflow {
|
namespace vllm {
|
||||||
|
|
||||||
// TODO(woosuk): Further optimize this kernel.
|
// TODO(woosuk): Further optimize this kernel.
|
||||||
template<typename scalar_t>
|
template<typename scalar_t>
|
||||||
@ -33,7 +34,7 @@ __global__ void rms_norm_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cacheflow
|
} // namespace vllm
|
||||||
|
|
||||||
void rms_norm(
|
void rms_norm(
|
||||||
torch::Tensor& out, // [num_tokens, hidden_size]
|
torch::Tensor& out, // [num_tokens, hidden_size]
|
||||||
@ -46,11 +47,11 @@ void rms_norm(
|
|||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(hidden_size, 1024));
|
dim3 block(std::min(hidden_size, 1024));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
input.scalar_type(),
|
input.scalar_type(),
|
||||||
"rms_norm_kernel",
|
"rms_norm_kernel",
|
||||||
[&] {
|
[&] {
|
||||||
cacheflow::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
out.data_ptr<scalar_t>(),
|
out.data_ptr<scalar_t>(),
|
||||||
input.data_ptr<scalar_t>(),
|
input.data_ptr<scalar_t>(),
|
||||||
weight.data_ptr<scalar_t>(),
|
weight.data_ptr<scalar_t>(),
|
||||||
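The kernel body is outside this hunk, but the launcher dispatches an RMSNorm over each token's hidden_size values. A hedged host-side reference, assuming the usual definition (scale by the reciprocal root-mean-square plus epsilon, then multiply by the weight); names are mine:

#include <cmath>
#include <cstddef>
#include <vector>

// Per-token RMSNorm reference: out[i] = x[i] / sqrt(mean(x^2) + eps) * weight[i].
std::vector<float> rms_norm_reference(const std::vector<float>& x,
                                      const std::vector<float>& weight,
                                      float eps) {
  float sum_sq = 0.f;
  for (float v : x) sum_sq += v * v;
  const float inv_rms = 1.f / std::sqrt(sum_sq / x.size() + eps);
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = x[i] * inv_rms * weight[i];
  }
  return out;
}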
|
|||||||
@ -1,14 +1,16 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
void rotary_embedding_neox(
|
void rotary_embedding(
|
||||||
torch::Tensor& positions,
|
torch::Tensor& positions,
|
||||||
torch::Tensor& query,
|
torch::Tensor& query,
|
||||||
torch::Tensor& key,
|
torch::Tensor& key,
|
||||||
torch::Tensor& cos_sin_cache);
|
int head_size,
|
||||||
|
torch::Tensor& cos_sin_cache,
|
||||||
|
bool is_neox);
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
m.def(
|
m.def(
|
||||||
"rotary_embedding_neox",
|
"rotary_embedding",
|
||||||
&rotary_embedding_neox,
|
&rotary_embedding,
|
||||||
"Apply GPT-NeoX style rotary embedding to query and key");
|
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,78 +1,127 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
namespace cacheflow {
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
template<typename scalar_t>
|
namespace vllm {
|
||||||
__global__ void rotary_embedding_neox_kernel(
|
|
||||||
|
template<typename scalar_t, bool IS_NEOX>
|
||||||
|
inline __device__ void apply_rotary_embedding(
|
||||||
|
scalar_t* __restrict__ arr,
|
||||||
|
const scalar_t* __restrict__ cos_ptr,
|
||||||
|
const scalar_t* __restrict__ sin_ptr,
|
||||||
|
int rot_offset,
|
||||||
|
int embed_dim)
|
||||||
|
{
|
||||||
|
int x_index, y_index;
|
||||||
|
scalar_t cos, sin;
|
||||||
|
if (IS_NEOX) {
|
||||||
|
// GPT-NeoX style rotary embedding.
|
||||||
|
x_index = rot_offset;
|
||||||
|
y_index = embed_dim + rot_offset;
|
||||||
|
cos = __ldg(cos_ptr + x_index);
|
||||||
|
sin = __ldg(sin_ptr + x_index);
|
||||||
|
} else {
|
||||||
|
// GPT-J style rotary embedding.
|
||||||
|
x_index = 2 * rot_offset;
|
||||||
|
y_index = 2 * rot_offset + 1;
|
||||||
|
cos = __ldg(cos_ptr + x_index / 2);
|
||||||
|
sin = __ldg(sin_ptr + x_index / 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
const scalar_t x = arr[x_index];
|
||||||
|
const scalar_t y = arr[y_index];
|
||||||
|
arr[x_index] = x * cos - y * sin;
|
||||||
|
arr[y_index] = y * cos + x * sin;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename scalar_t, bool IS_NEOX>
|
||||||
|
__global__ void rotary_embedding_kernel(
|
||||||
const int64_t* __restrict__ positions, // [num_tokens]
|
const int64_t* __restrict__ positions, // [num_tokens]
|
||||||
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
||||||
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
|
||||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, head_size // 2]
|
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||||
const int stride,
|
const int rot_dim,
|
||||||
|
const int query_stride,
|
||||||
|
const int key_stride,
|
||||||
const int num_heads,
|
const int num_heads,
|
||||||
|
const int num_kv_heads,
|
||||||
const int head_size) {
|
const int head_size) {
|
||||||
// Each thread block is responsible for one token.
|
// Each thread block is responsible for one token.
|
||||||
const int token_idx = blockIdx.x;
|
const int token_idx = blockIdx.x;
|
||||||
int64_t pos = positions[token_idx];
|
int64_t pos = positions[token_idx];
|
||||||
const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;
|
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||||
|
|
||||||
const int embed_dim = head_size / 2;
|
const int embed_dim = rot_dim / 2;
|
||||||
const int n = num_heads * embed_dim;
|
const scalar_t* cos_ptr = cache_ptr;
|
||||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
||||||
|
|
||||||
|
const int nq = num_heads * embed_dim;
|
||||||
|
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||||
const int head_idx = i / embed_dim;
|
const int head_idx = i / embed_dim;
|
||||||
const int token_head = token_idx * stride + head_idx * head_size;
|
const int token_head = token_idx * query_stride + head_idx * head_size;
|
||||||
|
|
||||||
const int rot_offset = i % embed_dim;
|
const int rot_offset = i % embed_dim;
|
||||||
const int x_index = rot_offset;
|
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
|
||||||
const int y_index = embed_dim + rot_offset;
|
sin_ptr, rot_offset, embed_dim);
|
||||||
|
}
|
||||||
|
|
||||||
const int out_x = token_idx * stride + head_idx * head_size + x_index;
|
const int nk = num_kv_heads * embed_dim;
|
||||||
const int out_y = token_idx * stride + head_idx * head_size + y_index;
|
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||||
|
const int head_idx = i / embed_dim;
|
||||||
const scalar_t cos = __ldg(cache_ptr + x_index);
|
const int token_head = token_idx * key_stride + head_idx * head_size;
|
||||||
const scalar_t sin = __ldg(cache_ptr + y_index);
|
const int rot_offset = i % embed_dim;
|
||||||
|
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
|
||||||
const scalar_t q_x = query[token_head + x_index];
|
sin_ptr, rot_offset, embed_dim);
|
||||||
const scalar_t q_y = query[token_head + y_index];
|
|
||||||
query[out_x] = q_x * cos - q_y * sin;
|
|
||||||
query[out_y] = q_y * cos + q_x * sin;
|
|
||||||
|
|
||||||
const scalar_t k_x = key[token_head + x_index];
|
|
||||||
const scalar_t k_y = key[token_head + y_index];
|
|
||||||
key[out_x] = k_x * cos - k_y * sin;
|
|
||||||
key[out_y] = k_y * cos + k_x * sin;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cacheflow
|
} // namespace vllm
|
||||||
|
|
||||||
void rotary_embedding_neox(
|
void rotary_embedding(
|
||||||
torch::Tensor& positions, // [num_tokens]
|
torch::Tensor& positions, // [num_tokens]
|
||||||
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
||||||
torch::Tensor& key, // [num_tokens, num_heads * head_size]
|
torch::Tensor& key, // [num_tokens, num_kv_heads * head_size]
|
||||||
torch::Tensor& cos_sin_cache) // [max_position, head_size]
|
int head_size,
|
||||||
{
|
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||||
|
bool is_neox) {
|
||||||
int num_tokens = query.size(0);
|
int num_tokens = query.size(0);
|
||||||
int head_size = cos_sin_cache.size(1);
|
int rot_dim = cos_sin_cache.size(1);
|
||||||
int num_heads = query.size(1) / head_size;
|
int num_heads = query.size(1) / head_size;
|
||||||
int stride = query.stride(0);
|
int num_kv_heads = key.size(1) / head_size;
|
||||||
TORCH_CHECK(stride == key.stride(0));
|
int query_stride = query.stride(0);
|
||||||
|
int key_stride = key.stride(0);
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(num_heads * head_size / 2, 512));
|
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
query.scalar_type(),
|
query.scalar_type(),
|
||||||
"rotary_embedding_neox",
|
"rotary_embedding",
|
||||||
[&] {
|
[&] {
|
||||||
cacheflow::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
if (is_neox) {
|
||||||
|
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
|
||||||
positions.data_ptr<int64_t>(),
|
positions.data_ptr<int64_t>(),
|
||||||
query.data_ptr<scalar_t>(),
|
query.data_ptr<scalar_t>(),
|
||||||
key.data_ptr<scalar_t>(),
|
key.data_ptr<scalar_t>(),
|
||||||
cos_sin_cache.data_ptr<scalar_t>(),
|
cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
stride,
|
rot_dim,
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
num_heads,
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
head_size);
|
head_size);
|
||||||
|
} else {
|
||||||
|
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
|
||||||
|
positions.data_ptr<int64_t>(),
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
|
rot_dim,
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_size);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
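The new kernel rotates index pairs whose layout depends on the style: GPT-NeoX pairs (i, i + embed_dim) within the rotated slice, GPT-J pairs adjacent elements (2i, 2i + 1), with embed_dim = rot_dim / 2 and the cos/sin caches indexed by i in both cases. A hedged host-side reference for a single head; names are mine:

#include <cmath>
#include <vector>

// Rotate the first rot_dim elements of one head; the rest of the head
// (if any) is left untouched, mirroring the kernel above.
void rotate_head_reference(std::vector<float>& head, int rot_dim,
                           const std::vector<float>& cos_cache,  // [rot_dim / 2]
                           const std::vector<float>& sin_cache,  // [rot_dim / 2]
                           bool is_neox) {
  const int embed_dim = rot_dim / 2;
  for (int i = 0; i < embed_dim; ++i) {
    const int x_index = is_neox ? i : 2 * i;
    const int y_index = is_neox ? embed_dim + i : 2 * i + 1;
    const float c = cos_cache[i];
    const float s = sin_cache[i];
    const float x = head[x_index];
    const float y = head[y_index];
    head[x_index] = x * c - y * s;
    head[y_index] = y * c + x * s;
  }
}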
|
||||||
|
|||||||
15
csrc/quantization.cpp
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
torch::Tensor awq_gemm(
|
||||||
|
torch::Tensor _in_feats,
|
||||||
|
torch::Tensor _kernel,
|
||||||
|
torch::Tensor _scaling_factors,
|
||||||
|
torch::Tensor _zeros,
|
||||||
|
int split_k_iters);
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def(
|
||||||
|
"awq_gemm",
|
||||||
|
&awq_gemm,
|
||||||
|
"Quantized GEMM for AWQ");
|
||||||
|
}
|
||||||
87
csrc/quantization/awq/dequantize.cuh
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
/*
|
||||||
|
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||||
|
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||||
|
@article{lin2023awq,
|
||||||
|
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
||||||
|
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
||||||
|
journal={arXiv},
|
||||||
|
year={2023}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
namespace awq {
|
||||||
|
|
||||||
|
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
|
||||||
|
{
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||||
|
assert(false);
|
||||||
|
#else
|
||||||
|
uint4 result;
|
||||||
|
|
||||||
|
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
|
||||||
|
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
|
||||||
|
|
||||||
|
// First, we extract the i4s and construct an intermediate fp16 number.
|
||||||
|
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||||
|
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
|
||||||
|
static constexpr uint32_t TOP_MASK = 0x00f000f0;
|
||||||
|
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||||
|
|
||||||
|
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
|
||||||
|
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
|
||||||
|
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
|
||||||
|
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
|
||||||
|
|
||||||
|
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
|
||||||
|
// immediately before required.
|
||||||
|
const uint32_t top_i4s = i4s >> 8;
|
||||||
|
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
|
||||||
|
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
: "=r"(h[0])
|
||||||
|
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||||
|
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
|
||||||
|
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
: "=r"(h[1])
|
||||||
|
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||||
|
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
|
||||||
|
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
: "=r"(h[2])
|
||||||
|
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||||
|
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
|
||||||
|
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
: "=r"(h[3])
|
||||||
|
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||||
|
|
||||||
|
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
|
||||||
|
// half2 ctor. In this case, I chose performance reliability over code readability.
|
||||||
|
|
||||||
|
// This is the half2 {1032, 1032} represented as an integer.
|
||||||
|
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
|
||||||
|
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
|
||||||
|
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
|
||||||
|
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
|
||||||
|
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
|
||||||
|
// This is the half2 {-72, -72} represented as an integer.
|
||||||
|
// static constexpr uint32_t NEG_72 = 0xd480d480;
|
||||||
|
// Haotian: Let's use {-64, -64}.
|
||||||
|
static constexpr uint32_t NEG_64 = 0xd400d400;
|
||||||
|
|
||||||
|
// Finally, we construct the output numbers.
|
||||||
|
// Convert elt_01
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
|
||||||
|
// Convert elt_23
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||||
|
// Convert elt_45
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
|
||||||
|
// Convert elt_67
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||||
|
|
||||||
|
return result;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace awq
|
||||||
|
} // namespace vllm
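dequantize_s4_to_fp16x2 recovers, for each packed 4-bit field, the unsigned nibble value as an fp16 (the [-8, 7] remapping is dropped per the comments; zero points and scales are applied by the caller), using lop3 and fp16 magic constants instead of per-element shifts. A hedged scalar sketch of the recovered values only, ignoring the interleaved lane ordering of the real output:

#include <cstdint>

// Scalar reference: value recovered for nibble k of a packed 32-bit word.
// The PTX path reaches the same numbers via (nibble | 0x6400) - 1024 for the
// low nibbles and fma((nibble << 4) | 0x6400, 1/16, -64) for the high nibbles.
static inline float dequant_nibble(uint32_t packed, int k /* 0..7 */) {
  return static_cast<float>((packed >> (4 * k)) & 0xFu);
}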
|
||||||
560
csrc/quantization/awq/gemm_kernels.cu
Normal file
@ -0,0 +1,560 @@
|
|||||||
|
/*
|
||||||
|
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||||
|
@article{lin2023awq,
|
||||||
|
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
||||||
|
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
||||||
|
journal={arXiv},
|
||||||
|
year={2023}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
|
#include "dequantize.cuh"
|
||||||
|
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
namespace awq {
|
||||||
|
|
||||||
|
// Pack two half values.
|
||||||
|
static inline __device__ __host__ unsigned
|
||||||
|
__pack_half2(const half x, const half y) {
|
||||||
|
unsigned v0 = *((unsigned short *)&x);
|
||||||
|
unsigned v1 = *((unsigned short *)&y);
|
||||||
|
return (v1 << 16) | v0;
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
||||||
|
{
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||||
|
assert(false);
|
||||||
|
#else
|
||||||
|
static constexpr uint32_t ZERO = 0x0;
|
||||||
|
float C_warp[32];
|
||||||
|
__shared__ half A_shared[16 * (32 + 8)];
|
||||||
|
__shared__ half B_shared[32 * (128 + 8)];
|
||||||
|
|
||||||
|
__shared__ half scaling_factors_shared[128];
|
||||||
|
__shared__ half zeros_shared[128];
|
||||||
|
|
||||||
|
int j_factors1 = ((OC + 128 - 1) / 128);
|
||||||
|
int blockIdx_x = 0;
|
||||||
|
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||||
|
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||||
|
|
||||||
|
half A_shared_warp[8];
|
||||||
|
half B_shared_warp[32];
|
||||||
|
for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
|
||||||
|
for (int i = 0; i < 8; ++i) {
|
||||||
|
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||||
|
static constexpr int row_stride = 2 * 32 * 8 / 128;
|
||||||
|
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
|
||||||
|
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||||
|
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||||
|
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||||
|
|
||||||
|
half* A_ptr = A
|
||||||
|
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
||||||
|
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||||
|
|
||||||
|
int* B_ptr = B
|
||||||
|
+ ((int)threadIdx.y) * (OC / 8) * 2
|
||||||
|
+ (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
|
||||||
|
+ (((int)threadIdx.x) % (128 / 8)) * 1;
|
||||||
|
// Why * 1 in the above line?
|
||||||
|
|
||||||
|
half* A_shared_ptr = A_shared
|
||||||
|
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
||||||
|
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
||||||
|
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||||
|
|
||||||
|
half* B_shared_ptr = B_shared
|
||||||
|
+ ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
|
||||||
|
+ (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
|
||||||
|
+ (((int)threadIdx.x) % (128 / 8)) * 8;
|
||||||
|
|
||||||
|
int* zeros_ptr = zeros
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
|
||||||
|
+ ((int)threadIdx.x) % (128 / 8);
|
||||||
|
|
||||||
|
half* scaling_factors_ptr = scaling_factors
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (128)
|
||||||
|
+ (((int)threadIdx.x) % (128 / 8)) * 8;
|
||||||
|
|
||||||
|
half* C_ptr = C
|
||||||
|
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * 128
|
||||||
|
+ ((int)threadIdx.y) * 64
|
||||||
|
+ (((int)threadIdx.x) % 4) * 2;
|
||||||
|
|
||||||
|
// preload s.f. and zeros
|
||||||
|
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||||
|
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
|
||||||
|
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
|
||||||
|
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||||
|
__syncthreads();
|
||||||
|
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||||
|
if (ld_A_flag)
|
||||||
|
{
|
||||||
|
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||||
|
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||||
|
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||||
|
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||||
|
/*
|
||||||
|
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
||||||
|
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||||
|
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||||
|
|
||||||
|
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
|
||||||
|
|
||||||
|
// B: 32 x 136 (128+8) float16
|
||||||
|
// each warp: 32 x 4
|
||||||
|
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
||||||
|
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
||||||
|
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||||
|
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||||
|
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||||
|
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||||
|
|
||||||
|
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||||
|
// - zero and * scale
|
||||||
|
// TODO (Haotian): can save 4 assembly instructions if we formulate it as deq = q * scale - zero * scale.
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||||
|
/*
|
||||||
|
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
||||||
|
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
// write back
|
||||||
|
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
|
||||||
|
{
|
||||||
|
unsigned int addr;
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||||
|
: "=r"(addr)
|
||||||
|
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||||
|
"{%0, %1, %2, %3}, [%4];\n"
|
||||||
|
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
||||||
|
: "r"(addr)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
|
||||||
|
{
|
||||||
|
unsigned int addr;
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||||
|
: "=r"(addr)
|
||||||
|
: "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||||
|
);
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||||
|
"{%0, %1, %2, %3}, [%4];\n"
|
||||||
|
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||||
|
: "r"(addr)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
|
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
|
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
|
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
|
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||||
|
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||||
|
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Shang: Hoist loop invariants.
|
||||||
|
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
|
||||||
|
for (int local_id = 0; local_id < 8; ++local_id) {
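// (Added note, not in the original diff: local_id enumerates the 8 fp32
// accumulators each thread owns per 16x16 output tile -- bit 0 picks the
// column inside a 2-wide pair, bit 1 picks the upper/lower 8-row half, and
// bit 2 picks the left/right 8-column half, matching the index math below.)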
|
||||||
|
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||||
|
if (row_offset < M)
|
||||||
|
{
|
||||||
|
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
||||||
|
{
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||||
|
assert(false);
|
||||||
|
#else
|
||||||
|
static constexpr uint32_t ZERO = 0x0;
|
||||||
|
float C_warp[32];
|
||||||
|
__shared__ half A_shared[16 * (32 + 8)];
|
||||||
|
__shared__ half B_shared[32 * (64 + 8)];
|
||||||
|
|
||||||
|
__shared__ half scaling_factors_shared[64];
|
||||||
|
__shared__ half zeros_shared[64];
|
||||||
|
|
||||||
|
int j_factors1 = ((OC + 64 - 1) / 64);
|
||||||
|
|
||||||
|
int blockIdx_x = 0;
|
||||||
|
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||||
|
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||||
|
|
||||||
|
half A_shared_warp[8];
|
||||||
|
half B_shared_warp[16];
|
||||||
|
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
|
||||||
|
for (int i = 0; i < 8; ++i) {
|
||||||
|
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||||
|
static constexpr int row_stride = 2 * 32 * 8 / 64;
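// (Added note, not in the original diff: each thread copies one uint4 = 8
// halves per B iteration, so the block's 2 warps x 32 threads move
// 2 * 32 * 8 = 512 halves = 512 / 64 = 8 rows of B, hence row_stride = 8.)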
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
|
||||||
|
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||||
|
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||||
|
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||||
|
|
||||||
|
half* A_ptr = A
|
||||||
|
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
||||||
|
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||||
|
|
||||||
|
int* B_ptr = B
|
||||||
|
+ ((int)threadIdx.y) * (OC / 8) * 4
|
||||||
|
+ (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
|
||||||
|
+ (((int)threadIdx.x) % (64 / 8)) * 1;
|
||||||
|
// Why * 1 in the above line?
|
||||||
|
|
||||||
|
half* A_shared_ptr = A_shared
|
||||||
|
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
||||||
|
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
||||||
|
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||||
|
|
||||||
|
half* B_shared_ptr = B_shared
|
||||||
|
+ ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
|
||||||
|
+ (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
|
||||||
|
+ (((int)threadIdx.x) % (64 / 8)) * 8;
|
||||||
|
|
||||||
|
int* zeros_ptr = zeros
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
|
||||||
|
+ ((int)threadIdx.x) % (64 / 8);
|
||||||
|
|
||||||
|
half* scaling_factors_ptr = scaling_factors
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (64)
|
||||||
|
+ (((int)threadIdx.x) % (64 / 8)) * 8;
|
||||||
|
|
||||||
|
half* C_ptr = C
|
||||||
|
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdx_z indexes the split_k dim
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * 64
|
||||||
|
+ ((int)threadIdx.y) * 32
|
||||||
|
+ (((int)threadIdx.x) % 4) * 2;
|
||||||
|
|
||||||
|
// preload s.f. and zeros
|
||||||
|
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||||
|
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
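// (Added note, not in the original diff: each split-k slice blockIdx_z walks
// the K tiles k_0_0 = blockIdx_z, blockIdx_z + split_k_iters, ...; the
// correction above drops the final tile for slices whose last k_0_0 * 32
// would start at or beyond IC.)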
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
|
||||||
|
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||||
|
__syncthreads();
|
||||||
|
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||||
|
if (ld_A_flag)
|
||||||
|
{
|
||||||
|
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||||
|
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||||
|
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||||
|
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||||
|
/*
|
||||||
|
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
||||||
|
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||||
|
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||||
|
|
||||||
|
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
|
||||||
|
|
||||||
|
// B: 32 x 136 (128+8) float16
|
||||||
|
// each warp: 32 x 4
|
||||||
|
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
||||||
|
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
||||||
|
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||||
|
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||||
|
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||||
|
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||||
|
|
||||||
|
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||||
|
// - zero and * scale
|
||||||
|
// TODO (Haotian): can save 4 assembly instructions if formulated as deq = q * scale - zero * scale.
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||||
|
/*
|
||||||
|
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
||||||
|
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
// write back
|
||||||
|
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
|
||||||
|
{
|
||||||
|
{
|
||||||
|
unsigned int addr;
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||||
|
: "=r"(addr)
|
||||||
|
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||||
|
);
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||||
|
"{%0, %1, %2, %3}, [%4];\n"
|
||||||
|
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
||||||
|
: "r"(addr)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
|
||||||
|
{
|
||||||
|
{
|
||||||
|
unsigned int addr;
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||||
|
: "=r"(addr)
|
||||||
|
: "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||||
|
);
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||||
|
"{%0, %1, %2, %3}, [%4];\n"
|
||||||
|
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||||
|
: "r"(addr)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
|
||||||
|
{
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
|
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
|
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
|
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||||
|
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||||
|
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||||
|
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Shang: Hoist loop invariants.
|
||||||
|
for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
|
||||||
|
for (int local_id = 0; local_id < 8; ++local_id) {
|
||||||
|
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||||
|
if (row_offset < M)
|
||||||
|
{
|
||||||
|
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace awq
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
// in_feats: M, IC [float16]
|
||||||
|
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
|
||||||
|
// scaling_factors: IC // G, OC [float16]
|
||||||
|
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
|
||||||
|
// assume that batch_size < 16 for now
|
||||||
|
|
||||||
|
torch::Tensor awq_gemm(
|
||||||
|
torch::Tensor _in_feats,
|
||||||
|
torch::Tensor _kernel,
|
||||||
|
torch::Tensor _scaling_factors,
|
||||||
|
torch::Tensor _zeros,
|
||||||
|
int split_k_iters)
|
||||||
|
{
|
||||||
|
int num_in_feats = _in_feats.size(0);
|
||||||
|
int num_in_channels = _in_feats.size(1);
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
|
||||||
|
|
||||||
|
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
|
||||||
|
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
|
||||||
|
int num_out_feats = _out_feats.size(-2);
|
||||||
|
int num_out_channels = _out_feats.size(-1);
|
||||||
|
|
||||||
|
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
|
||||||
|
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
|
||||||
|
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
|
||||||
|
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
|
||||||
|
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
|
||||||
|
int group_size = num_in_channels / _scaling_factors.size(0);
|
||||||
|
|
||||||
|
if (num_out_channels % 64 != 0)
|
||||||
|
throw std::invalid_argument("OC is not multiple of cta_N = 64");
|
||||||
|
if (num_out_channels % 8 != 0)
|
||||||
|
throw std::invalid_argument("OC is not multiple of pack_num = 8");
|
||||||
|
if (group_size % 32 != 0)
|
||||||
|
throw std::invalid_argument("Group size should be a multiple of 32");
|
||||||
|
if (num_out_channels % group_size != 0)
|
||||||
|
throw std::invalid_argument("OC is not multiple of Group size");
|
||||||
|
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
if (num_out_channels % 128 == 0)
|
||||||
|
{
|
||||||
|
int j_factors1 = num_out_channels / 128 / 1;
|
||||||
|
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||||
|
// threadIdx.x: 32
|
||||||
|
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||||
|
dim3 threads_per_block(32, 2);
|
||||||
|
vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
|
||||||
|
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||||
|
}
|
||||||
|
else if (num_out_channels % 64 == 0)
|
||||||
|
{
|
||||||
|
int j_factors1 = num_out_channels / 64 / 1;
|
||||||
|
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||||
|
|
||||||
|
// threadIdx.x: 32
|
||||||
|
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||||
|
dim3 threads_per_block(32, 2);
|
||||||
|
vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
|
||||||
|
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||||
|
}
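// (Added worked example, not in the original diff: for a hypothetical call
// with num_out_feats = 16, num_out_channels = 4096 and split_k_iters = 8,
// the 128-wide branch above is taken with j_factors1 = 4096 / 128 = 32, so
// num_blocks = (16 + 15) / 16 * 32 * 8 = 256 blocks of 32 x 2 = 64 threads.)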
return _out_feats.sum(0);
|
||||||
|
}
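// ---------------------------------------------------------------------------
// Added usage sketch, not part of the original diff: it only illustrates the
// tensor shapes documented in the comment block above awq_gemm, assuming
// hypothetical sizes M = 16, IC = 4096, OC = 4096 and group size G = 128.
// The helper name awq_gemm_shape_example is invented for illustration.
// ---------------------------------------------------------------------------
torch::Tensor awq_gemm_shape_example()
{
  const int64_t M = 16, IC = 4096, OC = 4096, G = 128;
  auto fp16 = torch::dtype(torch::kFloat16).device(torch::kCUDA);
  auto i32  = torch::dtype(torch::kInt32).device(torch::kCUDA);
  auto in_feats        = torch::randn({M, IC}, fp16);          // M x IC activations
  auto kernel          = torch::zeros({IC, OC / 8}, i32);      // packed 4-bit weights
  auto scaling_factors = torch::ones({IC / G, OC}, fp16);      // per-group scales
  auto zeros           = torch::zeros({IC / G, OC / 8}, i32);  // packed 4-bit zero points
  // Output is (M, OC) float16; the split-k partial results are summed inside awq_gemm.
  return awq_gemm(in_feats, kernel, scaling_factors, zeros, /*split_k_iters=*/8);
}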
51  csrc/reduction_utils.cuh  Normal file
@ -0,0 +1,51 @@
|
|||||||
|
/*
|
||||||
|
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
|
||||||
|
* Copyright (c) 2023, The vLLM team.
|
||||||
|
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__inline__ __device__ T warpReduceSum(T val) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = 16; mask > 0; mask >>= 1)
|
||||||
|
val += __shfl_xor_sync(0xffffffff, val, mask, 32);
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Calculate the sum of all elements in a block */
|
||||||
|
template<typename T>
|
||||||
|
__inline__ __device__ T blockReduceSum(T val) {
|
||||||
|
static __shared__ T shared[32];
|
||||||
|
int lane = threadIdx.x & 0x1f;
|
||||||
|
int wid = threadIdx.x >> 5;
|
||||||
|
|
||||||
|
val = warpReduceSum<T>(val);
|
||||||
|
|
||||||
|
if (lane == 0)
|
||||||
|
shared[wid] = val;
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Use blockDim.x / 32.f rather than blockDim.x >> 5 so that blocks whose
// size is not a multiple of 32 are still handled correctly.
|
||||||
|
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
|
||||||
|
val = warpReduceSum<T>(val);
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
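// ---------------------------------------------------------------------------
// Added usage sketch, not part of the original diff: a minimal kernel showing
// how vllm::blockReduceSum is typically used. The kernel name and the layout
// of `in`/`out` are illustrative only.
// ---------------------------------------------------------------------------
__global__ void block_sum_example(const float* __restrict__ in,
                                  float* __restrict__ out,
                                  int n) {
  // Each block reduces one contiguous row of n floats.
  float val = 0.0f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    val += in[blockIdx.x * n + i];
  }
  val = vllm::blockReduceSum<float>(val);
  // After the reduction, warp 0 holds the block-wide sum; thread 0 writes it.
  if (threadIdx.x == 0) {
    out[blockIdx.x] = val;
  }
}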
@ -1,76 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
namespace cacheflow {
|
|
||||||
|
|
||||||
template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
|
|
||||||
inline __device__ float block_sum(float* red_smem, float sum)
|
|
||||||
{
|
|
||||||
|
|
||||||
// Decompose the thread index into warp / lane.
|
|
||||||
int warp = threadIdx.x / WARP_SIZE;
|
|
||||||
int lane = threadIdx.x % WARP_SIZE;
|
|
||||||
|
|
||||||
// Compute the sum per warp.
|
|
||||||
#pragma unroll
|
|
||||||
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
|
||||||
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Warp leaders store the data to shared memory.
|
|
||||||
if (lane == 0) {
|
|
||||||
red_smem[warp] = sum;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure the data is in shared memory.
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
// The warps compute the final sums.
|
|
||||||
if (lane < WARPS_PER_BLOCK) {
|
|
||||||
sum = red_smem[lane];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parallel reduction inside the warp.
|
|
||||||
#pragma unroll
|
|
||||||
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
|
|
||||||
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Broadcast to other threads.
|
|
||||||
return __shfl_sync(uint32_t(-1), sum, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
#define FINAL_MASK 0xffffffff
|
|
||||||
|
|
||||||
template<typename T>
|
|
||||||
__inline__ __device__ T warpReduceSum(T val)
|
|
||||||
{
|
|
||||||
#pragma unroll
|
|
||||||
for (int mask = 16; mask > 0; mask >>= 1)
|
|
||||||
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
|
|
||||||
return val;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Calculate the sum of all elements in a block */
|
|
||||||
template<typename T>
|
|
||||||
__inline__ __device__ T blockReduceSum(T val)
|
|
||||||
{
|
|
||||||
static __shared__ T shared[32];
|
|
||||||
int lane = threadIdx.x & 0x1f;
|
|
||||||
int wid = threadIdx.x >> 5;
|
|
||||||
|
|
||||||
val = warpReduceSum<T>(val);
|
|
||||||
|
|
||||||
if (lane == 0)
|
|
||||||
shared[wid] = val;
|
|
||||||
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
|
|
||||||
// blockDim.x is not divided by 32
|
|
||||||
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
|
|
||||||
val = warpReduceSum<T>(val);
|
|
||||||
|
|
||||||
return val;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace cacheflow
|
|
||||||
20  docs/Makefile  Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# Minimal makefile for Sphinx documentation
|
||||||
|
#
|
||||||
|
|
||||||
|
# You can set these variables from the command line, and also
|
||||||
|
# from the environment for the first two.
|
||||||
|
SPHINXOPTS ?=
|
||||||
|
SPHINXBUILD ?= sphinx-build
|
||||||
|
SOURCEDIR = source
|
||||||
|
BUILDDIR = build
|
||||||
|
|
||||||
|
# Put it first so that "make" without argument is like "make help".
|
||||||
|
help:
|
||||||
|
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||||
|
|
||||||
|
.PHONY: help Makefile
|
||||||
|
|
||||||
|
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||||
|
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||||
|
%: Makefile
|
||||||
|
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||||
19  docs/README.md  Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# vLLM documents
|
||||||
|
|
||||||
|
## Build the docs
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install dependencies.
|
||||||
|
pip install -r requirements-docs.txt
|
||||||
|
|
||||||
|
# Build the docs.
|
||||||
|
make clean
|
||||||
|
make html
|
||||||
|
```
|
||||||
|
|
||||||
|
## Open the docs with your browser
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m http.server -d build/html/
|
||||||
|
```
|
||||||
|
Launch your browser and open localhost:8000.
|
||||||
35  docs/make.bat  Normal file
@ -0,0 +1,35 @@
|
|||||||
|
@ECHO OFF
|
||||||
|
|
||||||
|
pushd %~dp0
|
||||||
|
|
||||||
|
REM Command file for Sphinx documentation
|
||||||
|
|
||||||
|
if "%SPHINXBUILD%" == "" (
|
||||||
|
set SPHINXBUILD=sphinx-build
|
||||||
|
)
|
||||||
|
set SOURCEDIR=source
|
||||||
|
set BUILDDIR=build
|
||||||
|
|
||||||
|
%SPHINXBUILD% >NUL 2>NUL
|
||||||
|
if errorlevel 9009 (
|
||||||
|
echo.
|
||||||
|
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||||
|
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||||
|
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||||
|
echo.may add the Sphinx directory to PATH.
|
||||||
|
echo.
|
||||||
|
echo.If you don't have Sphinx installed, grab it from
|
||||||
|
echo.https://www.sphinx-doc.org/
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
if "%1" == "" goto help
|
||||||
|
|
||||||
|
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||||
|
goto end
|
||||||
|
|
||||||
|
:help
|
||||||
|
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||||
|
|
||||||
|
:end
|
||||||
|
popd
|
||||||
3  docs/requirements-docs.txt  Normal file
@ -0,0 +1,3 @@
|
|||||||
|
sphinx == 6.2.1
|
||||||
|
sphinx-book-theme == 1.0.1
|
||||||
|
sphinx-copybutton == 0.5.2
|
||||||
BIN  docs/source/assets/logos/vllm-logo-only-light.png  Normal file (binary, 53 KiB, not shown)
BIN  docs/source/assets/logos/vllm-logo-text-dark.png  Normal file (binary, 86 KiB, not shown)
BIN  docs/source/assets/logos/vllm-logo-text-light.png  Normal file (binary, 88 KiB, not shown)
67  docs/source/conf.py  Normal file
@ -0,0 +1,67 @@
|
|||||||
|
# Configuration file for the Sphinx documentation builder.
|
||||||
|
#
|
||||||
|
# This file only contains a selection of the most common options. For a full
|
||||||
|
# list see the documentation:
|
||||||
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||||
|
|
||||||
|
# -- Path setup --------------------------------------------------------------
|
||||||
|
|
||||||
|
# If extensions (or modules to document with autodoc) are in another directory,
|
||||||
|
# add these directories to sys.path here. If the directory is relative to the
|
||||||
|
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||||
|
#
|
||||||
|
# import os
|
||||||
|
# import sys
|
||||||
|
# sys.path.insert(0, os.path.abspath('.'))
|
||||||
|
|
||||||
|
|
||||||
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
|
project = 'vLLM'
|
||||||
|
copyright = '2023, vLLM Team'
|
||||||
|
author = 'the vLLM Team'
|
||||||
|
|
||||||
|
|
||||||
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
|
# Add any Sphinx extension module names here, as strings. They can be
|
||||||
|
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||||
|
# ones.
|
||||||
|
extensions = [
|
||||||
|
"sphinx.ext.napoleon",
|
||||||
|
"sphinx.ext.viewcode",
|
||||||
|
"sphinx.ext.intersphinx",
|
||||||
|
"sphinx_copybutton",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add any paths that contain templates here, relative to this directory.
|
||||||
|
templates_path = ['_templates']
|
||||||
|
|
||||||
|
# List of patterns, relative to source directory, that match files and
|
||||||
|
# directories to ignore when looking for source files.
|
||||||
|
# This pattern also affects html_static_path and html_extra_path.
|
||||||
|
exclude_patterns = []
|
||||||
|
|
||||||
|
# Exclude the prompt "$" when copying code
|
||||||
|
copybutton_prompt_text = r"\$ "
|
||||||
|
copybutton_prompt_is_regexp = True
|
||||||
|
|
||||||
|
# -- Options for HTML output -------------------------------------------------
|
||||||
|
|
||||||
|
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||||
|
# a list of builtin themes.
|
||||||
|
#
|
||||||
|
html_title = project
|
||||||
|
html_theme = 'sphinx_book_theme'
|
||||||
|
html_logo = 'assets/logos/vllm-logo-text-light.png'
|
||||||
|
html_theme_options = {
|
||||||
|
'logo_only': True,
|
||||||
|
'path_to_docs': 'docs/source',
|
||||||
|
'repository_url': 'https://github.com/vllm-project/vllm',
|
||||||
|
'use_repository_button': True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add any paths that contain custom static files (such as style sheets) here,
|
||||||
|
# relative to this directory. They are copied after the builtin static files,
|
||||||
|
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||||
|
html_static_path = ['_static']
|
||||||
50  docs/source/getting_started/installation.rst  Normal file
@ -0,0 +1,50 @@
|
|||||||
|
.. _installation:
|
||||||
|
|
||||||
|
Installation
|
||||||
|
============
|
||||||
|
|
||||||
|
vLLM is a Python library that also contains pre-compiled C++ and CUDA (11.8) binaries.
|
||||||
|
|
||||||
|
Requirements
|
||||||
|
------------
|
||||||
|
|
||||||
|
* OS: Linux
|
||||||
|
* Python: 3.8 -- 3.11
|
||||||
|
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
|
||||||
|
|
||||||
|
Install with pip
|
||||||
|
----------------
|
||||||
|
|
||||||
|
You can install vLLM using pip:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ # (Optional) Create a new conda environment.
|
||||||
|
$ conda create -n myenv python=3.8 -y
|
||||||
|
$ conda activate myenv
|
||||||
|
|
||||||
|
$ # Install vLLM.
|
||||||
|
$ pip install vllm
|
||||||
|
|
||||||
|
|
||||||
|
.. _build_from_source:
|
||||||
|
|
||||||
|
Build from source
|
||||||
|
-----------------
|
||||||
|
|
||||||
|
You can also build and install vLLM from source:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ git clone https://github.com/vllm-project/vllm.git
|
||||||
|
$ cd vllm
|
||||||
|
$ pip install -e . # This may take 5-10 minutes.
|
||||||
|
|
||||||
|
.. tip::
|
||||||
|
If you have trouble building vLLM, we recommend using the NVIDIA PyTorch Docker image.
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ # Pull the Docker image with CUDA 11.8.
|
||||||
|
$ # Use `--ipc=host` to make sure the shared memory is large enough.
|
||||||
|
$ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:22.12-py3
|
||||||
131  docs/source/getting_started/quickstart.rst  Normal file
@ -0,0 +1,131 @@
|
|||||||
|
.. _quickstart:
|
||||||
|
|
||||||
|
Quickstart
|
||||||
|
==========
|
||||||
|
|
||||||
|
This guide shows how to use vLLM to:
|
||||||
|
|
||||||
|
* run offline batched inference on a dataset;
|
||||||
|
* build an API server for a large language model;
|
||||||
|
* start an OpenAI-compatible API server.
|
||||||
|
|
||||||
|
Be sure to complete the :ref:`installation instructions <installation>` before continuing with this guide.
|
||||||
|
|
||||||
|
Offline Batched Inference
|
||||||
|
-------------------------
|
||||||
|
|
||||||
|
We first show an example of using vLLM for offline batched inference on a dataset. In other words, we use vLLM to generate texts for a list of input prompts.
|
||||||
|
|
||||||
|
Import ``LLM`` and ``SamplingParams`` from vLLM. The ``LLM`` class is the main class for running offline inference with the vLLM engine. The ``SamplingParams`` class specifies the parameters for the sampling process.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
Define the list of input prompts and the sampling parameters for generation. The sampling temperature is set to 0.8 and the nucleus sampling probability is set to 0.95. For more information about the sampling parameters, refer to the `class definition <https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py>`_.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||||
|
|
||||||
|
Initialize vLLM's engine for offline inference with the ``LLM`` class and the `OPT-125M model <https://arxiv.org/abs/2205.01068>`_. The list of supported models can be found at :ref:`supported models <supported_models>`.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
llm = LLM(model="facebook/opt-125m")
|
||||||
|
|
||||||
|
Call ``llm.generate`` to generate the outputs. It adds the input prompts to the vLLM engine's waiting queue and executes the vLLM engine to generate the outputs with high throughput. The outputs are returned as a list of ``RequestOutput`` objects, which include all the output tokens.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
# Print the outputs.
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
|
||||||
|
|
||||||
|
The code example can also be found in `examples/offline_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py>`_.
|
||||||
|
|
||||||
|
|
||||||
|
API Server
|
||||||
|
----------
|
||||||
|
|
||||||
|
vLLM can be deployed as an LLM service. We provide an example `FastAPI <https://fastapi.tiangolo.com/>`_ server. Check `vllm/entrypoints/api_server.py <https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/api_server.py>`_ for the server implementation. The server uses the ``AsyncLLMEngine`` class to support asynchronous processing of incoming requests.
|
||||||
|
|
||||||
|
Start the server:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ python -m vllm.entrypoints.api_server
|
||||||
|
|
||||||
|
By default, this command starts the server at ``http://localhost:8000`` with the OPT-125M model.
|
||||||
|
|
||||||
|
Query the model in shell:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ curl http://localhost:8000/generate \
|
||||||
|
$ -d '{
|
||||||
|
$ "prompt": "San Francisco is a",
|
||||||
|
$ "use_beam_search": true,
|
||||||
|
$ "n": 4,
|
||||||
|
$ "temperature": 0
|
||||||
|
$ }'
|
||||||
|
|
||||||
|
See `examples/api_client.py <https://github.com/vllm-project/vllm/blob/main/examples/api_client.py>`_ for a more detailed client example.
|
||||||
|
|
||||||
|
OpenAI-Compatible Server
|
||||||
|
------------------------
|
||||||
|
|
||||||
|
vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API.
|
||||||
|
|
||||||
|
Start the server:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ python -m vllm.entrypoints.openai.api_server \
|
||||||
|
$ --model facebook/opt-125m
|
||||||
|
|
||||||
|
By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_ and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.
|
||||||
|
|
||||||
|
This server can be queried in the same format as OpenAI API. For example, list the models:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ curl http://localhost:8000/v1/models
|
||||||
|
|
||||||
|
Query the model with input prompts:
|
||||||
|
|
||||||
|
.. code-block:: console
|
||||||
|
|
||||||
|
$ curl http://localhost:8000/v1/completions \
|
||||||
|
$ -H "Content-Type: application/json" \
|
||||||
|
$ -d '{
|
||||||
|
$ "model": "facebook/opt-125m",
|
||||||
|
$ "prompt": "San Francisco is a",
|
||||||
|
$ "max_tokens": 7,
|
||||||
|
$ "temperature": 0
|
||||||
|
$ }'
|
||||||
|
|
||||||
|
Since this server is compatible with OpenAI API, you can use it as a drop-in replacement for any applications using OpenAI API. For example, another way to query the server is via the ``openai`` python package:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import openai
|
||||||
|
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||||
|
openai.api_key = "EMPTY"
|
||||||
|
openai.api_base = "http://localhost:8000/v1"
|
||||||
|
completion = openai.Completion.create(model="facebook/opt-125m",
|
||||||
|
prompt="San Francisco is a")
|
||||||
|
print("Completion result:", completion)
|
||||||
|
|
||||||
|
For a more detailed client example, refer to `examples/openai_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`_.
|
||||||
74  docs/source/index.rst  Normal file
@ -0,0 +1,74 @@
|
|||||||
|
Welcome to vLLM!
|
||||||
|
================
|
||||||
|
|
||||||
|
.. figure:: ./assets/logos/vllm-logo-text-light.png
|
||||||
|
:width: 60%
|
||||||
|
:align: center
|
||||||
|
:alt: vLLM
|
||||||
|
:class: no-scaled-link
|
||||||
|
|
||||||
|
.. raw:: html
|
||||||
|
|
||||||
|
<p style="text-align:center">
|
||||||
|
<strong>Easy, fast, and cheap LLM serving for everyone
|
||||||
|
</strong>
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<p style="text-align:center">
|
||||||
|
<script async defer src="https://buttons.github.io/buttons.js"></script>
|
||||||
|
<a class="github-button" href="https://github.com/vllm-project/vllm" data-show-count="true" data-size="large" aria-label="Star">Star</a>
|
||||||
|
<a class="github-button" href="https://github.com/vllm-project/vllm/subscription" data-icon="octicon-eye" data-size="large" aria-label="Watch">Watch</a>
|
||||||
|
<a class="github-button" href="https://github.com/vllm-project/vllm/fork" data-icon="octicon-repo-forked" data-size="large" aria-label="Fork">Fork</a>
|
||||||
|
</p>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
vLLM is a fast and easy-to-use library for LLM inference and serving.
|
||||||
|
|
||||||
|
vLLM is fast with:
|
||||||
|
|
||||||
|
* State-of-the-art serving throughput
|
||||||
|
* Efficient management of attention key and value memory with **PagedAttention**
|
||||||
|
* Continuous batching of incoming requests
|
||||||
|
* Optimized CUDA kernels
|
||||||
|
|
||||||
|
vLLM is flexible and easy to use with:
|
||||||
|
|
||||||
|
* Seamless integration with popular HuggingFace models
|
||||||
|
* High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
||||||
|
* Tensor parallelism support for distributed inference
|
||||||
|
* Streaming outputs
|
||||||
|
* OpenAI-compatible API server
|
||||||
|
|
||||||
|
For more information, check out the following:
|
||||||
|
|
||||||
|
* `vLLM announcing blog post <https://vllm.ai>`_ (intro to PagedAttention)
|
||||||
|
* `vLLM paper <https://arxiv.org/abs/2309.06180>`_ (SOSP 2023)
|
||||||
|
* `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency <https://www.anyscale.com/blog/continuous-batching-llm-inference>`_ by Cade Daniel et al.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Documentation
|
||||||
|
-------------
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 1
|
||||||
|
:caption: Getting Started
|
||||||
|
|
||||||
|
getting_started/installation
|
||||||
|
getting_started/quickstart
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 1
|
||||||
|
:caption: Serving
|
||||||
|
|
||||||
|
serving/distributed_serving
|
||||||
|
serving/run_on_sky
|
||||||
|
serving/deploying_with_triton
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 1
|
||||||
|
:caption: Models
|
||||||
|
|
||||||
|
models/supported_models
|
||||||
|
models/adding_model
|
||||||
94  docs/source/models/adding_model.rst  Normal file
@ -0,0 +1,94 @@
|
|||||||
|
.. _adding_a_new_model:
|
||||||
|
|
||||||
|
Adding a New Model
|
||||||
|
==================
|
||||||
|
|
||||||
|
This document provides a high-level guide on integrating a `HuggingFace Transformers <https://github.com/huggingface/transformers>`_ model into vLLM.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The complexity of adding a new model depends heavily on the model's architecture.
|
||||||
|
The process is considerably easier if the model shares a similar architecture with an existing model in vLLM.
|
||||||
|
However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex.
|
||||||
|
|
||||||
|
.. tip::
|
||||||
|
If you are encountering issues while integrating your model into vLLM, feel free to open an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ repository.
|
||||||
|
We will be happy to help you out!
|
||||||
|
|
||||||
|
|
||||||
|
0. Fork the vLLM repository
|
||||||
|
--------------------------------
|
||||||
|
|
||||||
|
Start by forking our `GitHub <https://github.com/vllm-project/vllm/>`_ repository and then :ref:`build it from source <build_from_source>`.
|
||||||
|
This gives you the ability to modify the codebase and test your model.
|
||||||
|
|
||||||
|
|
||||||
|
1. Bring your model code
|
||||||
|
------------------------
|
||||||
|
|
||||||
|
Clone the PyTorch model code from the HuggingFace Transformers repository and put it into the `vllm/model_executor/models <https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models>`_ directory.
|
||||||
|
For instance, vLLM's `OPT model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/opt.py>`_ was adapted from HuggingFace's `modeling_opt.py <https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py>`_ file.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
When copying the model code, make sure to review and adhere to the code's copyright and licensing terms.
|
||||||
|
|
||||||
|
|
||||||
|
2. Rewrite the :code:`forward` methods
|
||||||
|
--------------------------------------
|
||||||
|
|
||||||
|
Next, you need to rewrite the :code:`forward` methods of your model by following these steps:
|
||||||
|
|
||||||
|
1. Remove any unnecessary code, such as the code only used for training.
|
||||||
|
2. Change the input parameters:
|
||||||
|
|
||||||
|
.. code-block:: diff
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
- attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
- position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
- past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
- inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
- labels: Optional[torch.LongTensor] = None,
|
||||||
|
- use_cache: Optional[bool] = None,
|
||||||
|
- output_attentions: Optional[bool] = None,
|
||||||
|
- output_hidden_states: Optional[bool] = None,
|
||||||
|
- return_dict: Optional[bool] = None,
|
||||||
|
-) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
+ positions: torch.Tensor,
|
||||||
|
+ kv_caches: List[KVCache],
|
||||||
|
+ input_metadata: InputMetadata,
|
||||||
|
+ cache_events: Optional[List[torch.cuda.Event]],
|
||||||
|
+) -> SamplerOutput:
|
||||||
|
|
||||||
|
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
|
||||||
|
4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
|
||||||
|
If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM.
|
||||||
|
|
||||||
|
|
||||||
|
3. (Optional) Implement tensor parallelism support
|
||||||
|
--------------------------------------------------
|
||||||
|
|
||||||
|
If your model is too large to fit into a single GPU, you can use tensor parallelism to manage it.
|
||||||
|
To do this, substitute your model's linear and embedding layers with their tensor-parallel versions.
|
||||||
|
For the embedding layer, you can simply replace :code:`nn.Embedding` with :code:`VocabParallelEmbedding`.
|
||||||
|
When it comes to the linear layers, you should use either :code:`RowParallelLinear` or :code:`ColumnParallelLinear`.
|
||||||
|
Typically, :code:`ColumnParallelLinear` is used for QKV linear layers and the first linear layers of the MLP blocks.
|
||||||
|
For the remaining linear layers, :code:`RowParallelLinear` is used.
|
||||||
|
|
||||||
|
|
||||||
|
4. Implement the weight loading logic
|
||||||
|
-------------------------------------
|
||||||
|
|
||||||
|
You now need to implement the :code:`load_weights` method in your :code:`*ForCausalLM` class.
|
||||||
|
This method should load the weights from the HuggingFace's checkpoint file and assign them to the corresponding layers in your model.
|
||||||
|
While the process is straightforward for most layers, the tensor-parallel layers necessitate some additional care, as their weights should be partitioned across multiple GPUs.
|
||||||
|
|
||||||
|
|
||||||
|
5. Register your model
|
||||||
|
----------------------
|
||||||
|
|
||||||
|
Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/__init__.py>`_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/model_loader.py>`_.
|
||||||
75  docs/source/models/supported_models.rst  Normal file
@ -0,0 +1,75 @@
|
|||||||
|
.. _supported_models:

Supported Models
================

vLLM supports a variety of generative Transformer models in `HuggingFace Transformers <https://huggingface.co/models>`_.
The following is the list of model architectures that are currently supported by vLLM.
Alongside each architecture, we include some popular models that use it.

.. list-table::
  :widths: 25 25 50
  :header-rows: 1

  * - Architecture
    - Models
    - Example HuggingFace Models
  * - :code:`AquilaForCausalLM`
    - Aquila
    - :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc.
  * - :code:`BaiChuanForCausalLM`
    - Baichuan
    - :code:`baichuan-inc/Baichuan-7B`, :code:`baichuan-inc/Baichuan-13B-Chat`, etc.
  * - :code:`BloomForCausalLM`
    - BLOOM, BLOOMZ, BLOOMChat
    - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
  * - :code:`FalconForCausalLM`
    - Falcon
    - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
  * - :code:`GPT2LMHeadModel`
    - GPT-2
    - :code:`gpt2`, :code:`gpt2-xl`, etc.
  * - :code:`GPTBigCodeForCausalLM`
    - StarCoder, SantaCoder, WizardCoder
    - :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc.
  * - :code:`GPTJForCausalLM`
    - GPT-J
    - :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc.
  * - :code:`GPTNeoXForCausalLM`
    - GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
    - :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
  * - :code:`InternLMForCausalLM`
    - InternLM
    - :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc.
  * - :code:`LlamaForCausalLM`
    - LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
    - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, etc.
  * - :code:`MistralForCausalLM`
    - Mistral, Mistral-Instruct
    - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
  * - :code:`MPTForCausalLM`
    - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
    - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
  * - :code:`OPTForCausalLM`
    - OPT, OPT-IML
    - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
  * - :code:`QWenLMHeadModel`
    - Qwen
    - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.

If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ project.

.. tip::
    The easiest way to check if your model is supported is to run the program below:

    .. code-block:: python

        from vllm import LLM

        llm = LLM(model=...)  # Name or path of your model
        output = llm.generate("Hello, my name is")
        print(output)

    If vLLM successfully generates text, it indicates that your model is supported.

6  docs/source/serving/deploying_with_triton.rst  Normal file
@@ -0,0 +1,6 @@
.. _deploying_with_triton:

Deploying with NVIDIA Triton
============================

The `Triton Inference Server <https://github.com/triton-inference-server>`_ hosts a tutorial demonstrating how to quickly deploy a simple `facebook/opt-125m <https://huggingface.co/facebook/opt-125m>`_ model using vLLM. Please see `Deploying a vLLM model in Triton <https://github.com/triton-inference-server/tutorials/blob/main/Quick_Deploy/vLLM/README.md#deploying-a-vllm-model-in-triton>`_ for more details.

38  docs/source/serving/distributed_serving.rst  Normal file
@@ -0,0 +1,38 @@
.. _distributed_serving:

Distributed Inference and Serving
=================================

vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We manage the distributed runtime with `Ray <https://github.com/ray-project/ray>`_. To run distributed inference, install Ray with:

.. code-block:: console

    $ pip install ray

To run multi-GPU inference with the :code:`LLM` class, set the :code:`tensor_parallel_size` argument to the number of GPUs you want to use. For example, to run inference on 4 GPUs:

.. code-block:: python

    from vllm import LLM
    llm = LLM("facebook/opt-13b", tensor_parallel_size=4)
    output = llm.generate("San Francisco is a")

To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument when starting the server. For example, to run the API server on 4 GPUs:

.. code-block:: console

    $ python -m vllm.entrypoints.api_server \
    $     --model facebook/opt-13b \
    $     --tensor-parallel-size 4

To scale vLLM beyond a single machine, start a `Ray runtime <https://docs.ray.io/en/latest/ray-core/starting-ray.html>`_ via CLI before running vLLM:

.. code-block:: console

    $ # On head node
    $ ray start --head

    $ # On worker nodes
    $ ray start --address=<ray-head-address>

After that, you can run inference and serving across multiple machines by launching the vLLM process on the head node and setting :code:`tensor_parallel_size` to the total number of GPUs across all machines.
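
For example, with a hypothetical cluster of two machines with 4 GPUs each (8 GPUs in total, all joined to the same Ray cluster as shown above), you would run on the head node:

.. code-block:: python

    from vllm import LLM

    # Hypothetical setup: 2 nodes x 4 GPUs = 8 GPUs in total.
    llm = LLM("facebook/opt-13b", tensor_parallel_size=8)
    output = llm.generate("San Francisco is a")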

69  docs/source/serving/run_on_sky.rst  Normal file
@@ -0,0 +1,69 @@
.. _on_cloud:

Running on clouds with SkyPilot
===============================

.. raw:: html

    <p align="center">
        <img src="https://imgur.com/yxtzPEu.png" alt="vLLM"/>
    </p>

vLLM can be run on the cloud to scale to multiple GPUs with `SkyPilot <https://github.com/skypilot-org/skypilot>`__, an open-source framework for running LLMs on any cloud.

To install SkyPilot and set up your cloud credentials, run:

.. code-block:: console

    $ pip install skypilot
    $ sky check

See the vLLM SkyPilot YAML for serving: `serving.yaml <https://github.com/skypilot-org/skypilot/blob/master/llm/vllm/serve.yaml>`__.

.. code-block:: yaml

    resources:
      accelerators: A100

    envs:
      MODEL_NAME: decapoda-research/llama-13b-hf
      TOKENIZER: hf-internal-testing/llama-tokenizer

    setup: |
      conda create -n vllm python=3.9 -y
      conda activate vllm
      git clone https://github.com/vllm-project/vllm.git
      cd vllm
      pip install .
      pip install gradio

    run: |
      conda activate vllm
      echo 'Starting vllm api server...'
      python -u -m vllm.entrypoints.api_server \
          --model $MODEL_NAME \
          --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
          --tokenizer $TOKENIZER 2>&1 | tee api_server.log &
      echo 'Waiting for vllm api server to start...'
      while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done
      echo 'Starting gradio server...'
      python vllm/examples/gradio_webserver.py

Start serving the LLaMA-13B model on an A100 GPU:

.. code-block:: console

    $ sky launch serving.yaml

Check the output of the command. There will be a shareable Gradio link (like the last line of the following). Open it in your browser to use the LLaMA model for text completion.

.. code-block:: console

    (task, pid=7431) Running on public URL: https://<gradio-hash>.gradio.live

**Optional**: Serve the 65B model instead of the default 13B and use more GPUs:

.. code-block:: console

    $ sky launch -c vllm-serve-new -s serve.yaml --gpus A100:8 --env MODEL_NAME=decapoda-research/llama-65b-hf

77  examples/api_client.py  Normal file
@@ -0,0 +1,77 @@
"""Example Python client for vllm.entrypoints.api_server"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from typing import Iterable, List
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def clear_line(n: int = 1) -> None:
|
||||||
|
LINE_UP = '\033[1A'
|
||||||
|
LINE_CLEAR = '\x1b[2K'
|
||||||
|
for _ in range(n):
|
||||||
|
print(LINE_UP, end=LINE_CLEAR, flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
def post_http_request(prompt: str,
|
||||||
|
api_url: str,
|
||||||
|
n: int = 1,
|
||||||
|
stream: bool = False) -> requests.Response:
|
||||||
|
headers = {"User-Agent": "Test Client"}
|
||||||
|
pload = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"n": n,
|
||||||
|
"use_beam_search": True,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"max_tokens": 16,
|
||||||
|
"stream": stream,
|
||||||
|
}
|
||||||
|
response = requests.post(api_url, headers=headers, json=pload, stream=True)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
|
||||||
|
for chunk in response.iter_lines(chunk_size=8192,
|
||||||
|
decode_unicode=False,
|
||||||
|
delimiter=b"\0"):
|
||||||
|
if chunk:
|
||||||
|
data = json.loads(chunk.decode("utf-8"))
|
||||||
|
output = data["text"]
|
||||||
|
yield output
|
||||||
|
|
||||||
|
|
||||||
|
def get_response(response: requests.Response) -> List[str]:
|
||||||
|
data = json.loads(response.content)
|
||||||
|
output = data["text"]
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--host", type=str, default="localhost")
|
||||||
|
parser.add_argument("--port", type=int, default=8000)
|
||||||
|
parser.add_argument("--n", type=int, default=4)
|
||||||
|
parser.add_argument("--prompt", type=str, default="San Francisco is a")
|
||||||
|
parser.add_argument("--stream", action="store_true")
|
||||||
|
args = parser.parse_args()
|
||||||
|
prompt = args.prompt
|
||||||
|
api_url = f"http://{args.host}:{args.port}/generate"
|
||||||
|
n = args.n
|
||||||
|
stream = args.stream
|
||||||
|
|
||||||
|
print(f"Prompt: {prompt!r}\n", flush=True)
|
||||||
|
response = post_http_request(prompt, api_url, n, stream)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
num_printed_lines = 0
|
||||||
|
for h in get_streaming_response(response):
|
||||||
|
clear_line(num_printed_lines)
|
||||||
|
num_printed_lines = 0
|
||||||
|
for i, line in enumerate(h):
|
||||||
|
num_printed_lines += 1
|
||||||
|
print(f"Beam candidate {i}: {line!r}", flush=True)
|
||||||
|
else:
|
||||||
|
output = get_response(response)
|
||||||
|
for i, line in enumerate(output):
|
||||||
|
print(f"Beam candidate {i}: {line!r}", flush=True)
|
||||||

52  examples/gradio_webserver.py  Normal file
@@ -0,0 +1,52 @@
import argparse
import json

import gradio as gr
import requests


def http_bot(prompt):
    headers = {"User-Agent": "vLLM Client"}
    pload = {
        "prompt": prompt,
        "stream": True,
        "max_tokens": 128,
    }
    response = requests.post(args.model_url,
                             headers=headers,
                             json=pload,
                             stream=True)

    for chunk in response.iter_lines(chunk_size=8192,
                                     decode_unicode=False,
                                     delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"][0]
            yield output


def build_demo():
    with gr.Blocks() as demo:
        gr.Markdown("# vLLM text completion demo\n")
        inputbox = gr.Textbox(label="Input",
                              placeholder="Enter text and press ENTER")
        outputbox = gr.Textbox(label="Output",
                               placeholder="Generated result from the model")
        inputbox.submit(http_bot, [inputbox], [outputbox])
    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8001)
    parser.add_argument("--model-url",
                        type=str,
                        default="http://localhost:8000/generate")
    args = parser.parse_args()

    demo = build_demo()
    demo.queue(concurrency_count=100).launch(server_name=args.host,
                                             server_port=args.port,
                                             share=True)

51  examples/llm_engine_example.py  Normal file
@@ -0,0 +1,51 @@
import argparse

from vllm import EngineArgs, LLMEngine, SamplingParams


def main(args: argparse.Namespace):
    # Parse the CLI argument and initialize the engine.
    engine_args = EngineArgs.from_cli_args(args)
    engine = LLMEngine.from_engine_args(engine_args)

    # Test the following prompts.
    test_prompts = [
        ("A robot may not injure a human being",
         SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)),
        ("To be or not to be,",
         SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
        ("What is the meaning of life?",
         SamplingParams(n=2,
                        best_of=5,
                        temperature=0.8,
                        top_p=0.95,
                        frequency_penalty=0.1)),
        ("It is only with the heart that one can see rightly",
         SamplingParams(n=3, best_of=3, use_beam_search=True,
                        temperature=0.0)),
    ]

    # Run the engine by calling `engine.step()` manually.
    request_id = 0
    while True:
        # To test continuous batching, we add one request at each step.
        if test_prompts:
            prompt, sampling_params = test_prompts.pop(0)
            engine.add_request(str(request_id), prompt, sampling_params)
            request_id += 1

        request_outputs = engine.step()
        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)

        if not (engine.has_unfinished_requests() or test_prompts):
            break


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Demo on using the LLMEngine class directly')
    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    main(args)

22  examples/offline_inference.py  Normal file
@@ -0,0 +1,22 @@
from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
Some files were not shown because too many files have changed in this diff.