Compare commits
157 Commits
SHA1 | Author | Date | |
---|---|---|---|
c5f7740d89 | |||
be66d9b125 | |||
e1054247ba | |||
8d17774f92 | |||
e946260cf3 | |||
edb305584b | |||
bb00f66e19 | |||
e87557b069 | |||
dcc543a298 | |||
0fc280b06c | |||
20d0699d49 | |||
686f5e3210 | |||
415d109527 | |||
521b35f799 | |||
cb08cd0d75 | |||
2a2c135b41 | |||
65ea2ddf17 | |||
b514d3c496 | |||
7076fa1c9f | |||
660a7fcfa4 | |||
054072bee5 | |||
eb825c1e74 | |||
1b290ace4f | |||
0d578228ca | |||
aebfcb262a | |||
ab9e8488d5 | |||
fd58b73a40 | |||
8efe23f150 | |||
06458a0b42 | |||
1a2bbc9301 | |||
e7f579eb97 | |||
8516999495 | |||
9f669a9a7c | |||
555bdcc5a3 | |||
54ca1ba71d | |||
9738b84a08 | |||
1fe0990023 | |||
7e90a2d117 | |||
5687d584fe | |||
cf8849f2d6 | |||
e575df33b1 | |||
0ce8647dc5 | |||
9cabcb7645 | |||
7b895c5976 | |||
7013a80170 | |||
79a30912b8 | |||
2f3d36a8a1 | |||
ac8d36f3e5 | |||
15f5632365 | |||
aa9af07cac | |||
69be658bba | |||
beac8dd461 | |||
28b47d1e49 | |||
1f24755bf8 | |||
bf31d3606a | |||
d189170b6c | |||
f61dc8072f | |||
f8a1e39fae | |||
a132435204 | |||
9524867701 | |||
c1376e0f82 | |||
651c614aa4 | |||
d3a5bd9fb7 | |||
e8ef4c0820 | |||
348897af31 | |||
9d9072a069 | |||
928de46888 | |||
29678cd213 | |||
d0740dff1b | |||
de89472897 | |||
e7c8555d06 | |||
ec3b5ce9cc | |||
6368e777a8 | |||
875afe38ab | |||
ee8217e5be | |||
980dd4a2c4 | |||
8285736840 | |||
91fce82c6f | |||
ac5cf86aa6 | |||
6a6119554c | |||
b95ee898fe | |||
9eed4d1f3e | |||
6b5296aa3a | |||
ee92b58b3a | |||
09ff7f106a | |||
acbed3ef40 | |||
66d18a7fb0 | |||
ba0bfd40e2 | |||
84e4e37d14 | |||
a60b353005 | |||
ebe4d1db3a | |||
b5a10eb0ef | |||
0967102c6d | |||
e2fb71ec9f | |||
f936657eb6 | |||
6f88f762bf | |||
202351d5bf | |||
2e8e49fce3 | |||
a8e98aee0c | |||
bb1ba58f06 | |||
7bedab5748 | |||
20f7cc4cde | |||
649aa730c5 | |||
a19bc5c628 | |||
28e616c4e3 | |||
30e775281d | |||
21877b0d75 | |||
cf5cb1e33e | |||
03ffd0a022 | |||
a425bd9a9a | |||
bbbf86565f | |||
9f6be8692e | |||
f187877945 | |||
947b794146 | |||
8d926e91f1 | |||
4ee52bb169 | |||
7d7e3b78a3 | |||
f98b745a81 | |||
2d1e86f1b1 | |||
1ac4ccf73c | |||
2ac4d5e2bf | |||
3302f0aef3 | |||
6f2dd6c37e | |||
bc0644574c | |||
400b8289f7 | |||
c1026311b5 | |||
2b1c116b5a | |||
cc796b1358 | |||
f029ef94d7 | |||
95592fa00a | |||
fbe66e1d0b | |||
90979c38f8 | |||
e21d7687a9 | |||
ff36139ffc | |||
e3e79e9e8a | |||
b9fe4616f9 | |||
64ca424e75 | |||
b5f93d0631 | |||
a58936966f | |||
dd54a4b026 | |||
eda1a7cad3 | |||
f04908cae7 | |||
ab019eea75 | |||
9841d48a10 | |||
3272d7a0b7 | |||
0bb1e885a0 | |||
d6545ad22e | |||
90eb3f43ca | |||
e67b4f2c2a | |||
d6770d1f23 | |||
b9cecc2635 | |||
898285c9bf | |||
a62de9ecfd | |||
4042d192f5 | |||
1117aa1411 | |||
080438477f | |||
4b5bcf8906 |
.github/workflows/publish.yml (11 changes)
@@ -43,13 +43,14 @@ jobs:
name: Build Wheel
runs-on: ${{ matrix.os }}
needs: release

strategy:
fail-fast: false
matrix:
os: ['ubuntu-20.04']
python-version: ['3.8', '3.9', '3.10', '3.11']
cuda-version: ['11.8'] # Github runner can't build anything older than 11.8
pytorch-version: ['2.1.0']
cuda-version: ['11.8', '12.1']

steps:
- name: Checkout
@@ -69,9 +70,9 @@ jobs:
run: |
bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}

- name: Install PyTorch-cu${{ matrix.cuda-version }}
- name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }}
run: |
bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }}

- name: Build wheel
shell: bash
@@ -81,7 +82,7 @@ jobs:
asset_name=${wheel_name//"linux"/"manylinux1"}
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
echo "asset_name=${asset_name}" >> $GITHUB_ENV

- name: Upload Release Asset
uses: actions/upload-release-asset@v1
env:
.github/workflows/pylint.yml (2 changes)
@@ -28,4 +28,4 @@ jobs:
pip install pylint==2.8.2
- name: Analysing the code with pylint
run: |
pylint vllm
pylint vllm tests
.github/workflows/scripts/build.sh (3 changes)
@@ -11,5 +11,8 @@ LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
$python_executable -m pip install wheel packaging
$python_executable -m pip install -r requirements.txt

# Limit the number of parallel jobs to avoid OOM
export MAX_JOBS=1

# Build
$python_executable setup.py bdist_wheel --dist-dir=dist
.github/workflows/scripts/cuda-install.sh (5 changes)
@@ -16,3 +16,8 @@ sudo apt clean
# Test nvcc
PATH=/usr/local/cuda-$1/bin:${PATH}
nvcc --version

# Log gcc, g++, c++ versions
gcc --version
g++ --version
c++ --version
.github/workflows/scripts/pytorch-install.sh (5 changes)
@@ -1,11 +1,12 @@
#!/bin/bash

python_executable=python$1
cuda_version=$2
pytorch_version=$2
cuda_version=$3

# Install torch
$python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya
$python_executable -m pip install torch -f https://download.pytorch.org/whl/cu${cuda_version//./}/torch_stable.html
$python_executable -m pip install torch==${pytorch_version}+cu${cuda_version//./} --extra-index-url https://download.pytorch.org/whl/cu${cuda_version//./}

# Print version information
$python_executable --version
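For clarity on the argument change above: the updated pytorch-install.sh now takes the Python version, then the PyTorch version, then the CUDA version, matching the new workflow call. A minimal local invocation might look like this (the version numbers are simply the ones used in the CI matrix):

```bash
# Install PyTorch 2.1.0 built for CUDA 11.8 into the python3.10 environment.
bash -x .github/workflows/scripts/pytorch-install.sh 3.10 2.1.0 11.8
```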
.github/workflows/yapf.yml (2 changes)
@@ -28,4 +28,4 @@ jobs:
pip install toml==0.10.2
- name: Running yapf
run: |
yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**'
yapf --diff --recursive vllm tests
.gitignore (4 changes)
@@ -173,3 +173,7 @@ cython_debug/

# Sphinx documentation
_build/

# vim swap files
*.swo
*.swp
@@ -8,7 +8,7 @@
[MASTER]

# Files or directories to be skipped. They should be base names, not paths.
ignore=docs,parallel_utils
ignore=docs

# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
Dockerfile (72 changes, new file)
@@ -0,0 +1,72 @@
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev

RUN apt-get update -y \
&& apt-get install -y python3-pip

WORKDIR /workspace

# install build and runtime dependencies
COPY requirements.txt requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements.txt

# install development dependencies
COPY requirements-dev.txt requirements-dev.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements-dev.txt

# image to build pytorch extensions
FROM dev AS build

# copy input files
COPY csrc csrc
COPY setup.py setup.py
COPY requirements.txt requirements.txt
COPY pyproject.toml pyproject.toml
COPY vllm/__init__.py vllm/__init__.py

# max jobs used by Ninja to build extensions
ENV MAX_JOBS=$max_jobs
RUN python3 setup.py build_ext --inplace

# image to run unit testing suite
FROM dev AS test

# copy pytorch extensions separately to avoid having to rebuild
# when python code changes
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY tests tests
COPY vllm vllm

ENTRYPOINT ["python3", "-m", "pytest", "tests"]

# use CUDA base as CUDA runtime dependencies are already installed via pip
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base

# libnccl required for ray
RUN apt-get update -y \
&& apt-get install -y python3-pip

WORKDIR /workspace
COPY requirements.txt requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements.txt

FROM vllm-base AS vllm
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm

EXPOSE 8000
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"]

# openai api server alternative
FROM vllm-base AS vllm-openai
# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
pip install accelerate fschat

COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm

ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
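The new Dockerfile splits the image into dev, build, test, and serving stages. As a hedged sketch (the image tags and GPU flags below are illustrative, not part of the diff), the stages could be exercised like this:

```bash
# Build and run the test stage, whose entrypoint runs pytest.
docker build --target test -t vllm-test .
docker run --gpus all vllm-test

# Build the OpenAI-compatible serving stage and serve a model on the exposed port 8000.
docker build --target vllm-openai -t vllm-openai .
docker run --gpus all -p 8000:8000 vllm-openai --model facebook/opt-125m
```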
README.md (57 changes)
@@ -10,13 +10,16 @@ Easy, fast, and cheap LLM serving for everyone
</h3>

<p align="center">
| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://github.com/vllm-project/vllm/discussions"><b>Discussions</b></a> |
| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |

</p>

---

*Latest News* 🔥
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
@@ -35,17 +38,18 @@ vLLM is fast with:

vLLM is flexible and easy to use with:

- Seamless integration with popular HuggingFace models
- Seamless integration with popular Hugging Face models
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
- Tensor parallelism support for distributed inference
- Streaming outputs
- OpenAI-compatible API server

vLLM seamlessly supports many Huggingface models, including the following architectures:
vLLM seamlessly supports many Hugging Face models, including the following architectures:

- Aquila (`BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
@@ -53,9 +57,12 @@ vLLM seamlessly supports many Huggingface models, including the following archit
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Phi-1.5 (`microsoft/phi-1_5`, etc.)
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)

Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):

@@ -70,37 +77,19 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started
- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)

## Performance

vLLM outperforms HuggingFace Transformers (HF) by up to 24x and Text Generation Inference (TGI) by up to 3.5x, in terms of throughput.
For details, check out our [blog post](https://vllm.ai).

<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a10g_n1_dark.png">
<img src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a10g_n1_light.png" width="45%">
</picture>
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a100_n1_dark.png">
<img src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a100_n1_light.png" width="45%">
</picture>
<br>
<em> Serving throughput when each request asks for 1 output completion. </em>
</p>

<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a10g_n3_dark.png">
<img src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a10g_n3_light.png" width="45%">
</picture>
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a100_n3_dark.png">
<img src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a100_n3_light.png" width="45%">
</picture> <br>
<em> Serving throughput when each request asks for 3 output completions. </em>
</p>

## Contributing

We welcome and value any contributions and collaborations.
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.

## Citation

If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
```bibtex
@inproceedings{kwon2023efficient,
title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
year={2023}
}
```
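To complement the install line and the OpenAI-compatible server mentioned in the README diff above, a minimal serving sketch looks roughly like this (the model name, port, and prompt are placeholders based on the public vLLM quickstart, not taken from this diff):

```bash
pip install vllm

# Start the OpenAI-compatible server.
python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m

# Query it through the standard completions route.
curl http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "facebook/opt-125m", "prompt": "Hello, my name is", "max_tokens": 32}'
```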
@@ -18,10 +18,12 @@ def main(args: argparse.Namespace):
llm = LLM(
model=args.model,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
max_num_seqs=args.batch_size,
max_num_batched_tokens=args.batch_size * args.input_len,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
)

sampling_params = SamplingParams(
@@ -38,13 +40,13 @@ def main(args: argparse.Namespace):
def run_to_completion(profile: bool = False):
if profile:
torch.cuda.cudart().cudaProfilerStart()
start_time = time.time()
start_time = time.perf_counter()

llm.generate(prompt_token_ids=dummy_prompt_token_ids,
sampling_params=sampling_params,
use_tqdm=False)

end_time = time.time()
end_time = time.perf_counter()
latency = end_time - start_time
if profile:
torch.cuda.cudart().cudaProfilerStop()
@@ -63,19 +65,37 @@ def main(args: argparse.Namespace):
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', 'squeezellm', None],
default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
parser.add_argument('--output-len', type=int, default=128)
parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--n', type=int, default=1,
parser.add_argument('--n',
type=int,
default=1,
help='Number of generated sequences per prompt.')
parser.add_argument('--use-beam-search', action='store_true')
parser.add_argument('--num-iters', type=int, default=3,
parser.add_argument('--num-iters',
type=int,
default=3,
help='Number of iterations to run.')
parser.add_argument('--trust-remote-code', action='store_true',
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
args = parser.parse_args()
main(args)
@@ -105,7 +105,7 @@ async def send_request(
best_of: int,
use_beam_search: bool,
) -> None:
request_start_time = time.time()
request_start_time = time.perf_counter()

headers = {"User-Agent": "Benchmark Client"}
if backend == "vllm":
@@ -148,7 +148,7 @@ async def send_request(
if "error" not in output:
break

request_end_time = time.time()
request_end_time = time.perf_counter()
request_latency = request_end_time - request_start_time
REQUEST_LATENCY.append((prompt_len, output_len, request_latency))

@@ -180,10 +180,10 @@ def main(args: argparse.Namespace):
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

benchmark_start_time = time.time()
benchmark_start_time = time.perf_counter()
asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
args.use_beam_search, args.request_rate))
benchmark_end_time = time.time()
benchmark_end_time = time.perf_counter()
benchmark_time = benchmark_end_time - benchmark_start_time
print(f"Total time: {benchmark_time:.2f} s")
print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
@ -3,34 +3,32 @@ import argparse
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
PreTrainedTokenizerBase)
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
|
||||
def sample_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int],
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
if fixed_output_len is not None:
|
||||
if fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
|
||||
# Load the dataset.
|
||||
with open(dataset_path) as f:
|
||||
dataset = json.load(f)
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [
|
||||
data for data in dataset
|
||||
if len(data["conversations"]) >= 2
|
||||
]
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [
|
||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||
for data in dataset
|
||||
]
|
||||
dataset = [(data["conversations"][0]["value"],
|
||||
data["conversations"][1]["value"]) for data in dataset]
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompts = [prompt for prompt, _ in dataset]
|
||||
@ -40,6 +38,8 @@ def sample_requests(
|
||||
tokenized_dataset = []
|
||||
for i in range(len(dataset)):
|
||||
output_len = len(completion_token_ids[i])
|
||||
if fixed_output_len is not None:
|
||||
output_len = fixed_output_len
|
||||
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
|
||||
|
||||
# Filter out too long sequences.
|
||||
@ -63,18 +63,23 @@ def run_vllm(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
model: str,
|
||||
tokenizer: str,
|
||||
quantization: Optional[str],
|
||||
tensor_parallel_size: int,
|
||||
seed: int,
|
||||
n: int,
|
||||
use_beam_search: bool,
|
||||
trust_remote_code: bool,
|
||||
dtype: str,
|
||||
) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
quantization=quantization,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
seed=seed,
|
||||
trust_remote_code=trust_remote_code,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Add the requests to the engine.
|
||||
@ -94,10 +99,10 @@ def run_vllm(
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
# FIXME(woosuk): Do use internal method.
|
||||
start = time.perf_counter()
|
||||
# FIXME(woosuk): Do not use internal method.
|
||||
llm._run_engine(use_tqdm=True)
|
||||
end = time.time()
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
@ -111,15 +116,15 @@ def run_hf(
|
||||
trust_remote_code: bool,
|
||||
) -> float:
|
||||
assert not use_beam_search
|
||||
llm = AutoModelForCausalLM.from_pretrained(model,
|
||||
torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
||||
if llm.config.model_type == "llama":
|
||||
# To enable padding in the HF backend.
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
llm = llm.cuda()
|
||||
|
||||
pbar = tqdm(total=len(requests))
|
||||
start = time.time()
|
||||
start = time.perf_counter()
|
||||
batch: List[str] = []
|
||||
max_prompt_len = 0
|
||||
max_output_len = 0
|
||||
@ -132,13 +137,14 @@ def run_hf(
|
||||
if len(batch) < max_batch_size and i != len(requests) - 1:
|
||||
# Check if we can add more requests to the batch.
|
||||
_, next_prompt_len, next_output_len = requests[i + 1]
|
||||
if (max(max_prompt_len, next_prompt_len) + max(
|
||||
max_output_len, next_output_len)) <= 2048:
|
||||
if (max(max_prompt_len, next_prompt_len) +
|
||||
max(max_output_len, next_output_len)) <= 2048:
|
||||
# We can add more requests to the batch.
|
||||
continue
|
||||
|
||||
# Generate the sequences.
|
||||
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
|
||||
input_ids = tokenizer(batch, return_tensors="pt",
|
||||
padding=True).input_ids
|
||||
llm_outputs = llm.generate(
|
||||
input_ids=input_ids.cuda(),
|
||||
do_sample=not use_beam_search,
|
||||
@ -156,7 +162,23 @@ def run_hf(
|
||||
batch = []
|
||||
max_prompt_len = 0
|
||||
max_output_len = 0
|
||||
end = time.time()
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
def run_mii(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
model: str,
|
||||
tensor_parallel_size: int,
|
||||
output_len: int,
|
||||
) -> float:
|
||||
from mii import pipeline
|
||||
llm = pipeline(model, tensor_parallel=tensor_parallel_size)
|
||||
prompts = [prompt for prompt, _, _ in requests]
|
||||
|
||||
start = time.perf_counter()
|
||||
llm(prompts, max_new_tokens=output_len)
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
@ -165,49 +187,98 @@ def main(args: argparse.Namespace):
|
||||
random.seed(args.seed)
|
||||
|
||||
# Sample the requests.
|
||||
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
if args.dataset is None:
|
||||
# Synthesize a prompt with the given input length.
|
||||
prompt = "hi" * (args.input_len - 1)
|
||||
requests = [(prompt, args.input_len, args.output_len)
|
||||
for _ in range(args.num_prompts)]
|
||||
else:
|
||||
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
|
||||
args.output_len)
|
||||
|
||||
if args.backend == "vllm":
|
||||
elapsed_time = run_vllm(
|
||||
requests, args.model, args.tokenizer, args.tensor_parallel_size,
|
||||
args.seed, args.n, args.use_beam_search, args.trust_remote_code)
|
||||
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
|
||||
args.quantization, args.tensor_parallel_size,
|
||||
args.seed, args.n, args.use_beam_search,
|
||||
args.trust_remote_code, args.dtype)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(
|
||||
requests, args.model, tokenizer, args.n, args.use_beam_search,
|
||||
args.hf_max_batch_size, args.trust_remote_code)
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
args.use_beam_search, args.hf_max_batch_size,
|
||||
args.trust_remote_code)
|
||||
elif args.backend == "mii":
|
||||
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
|
||||
args.output_len)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
total_num_tokens = sum(
|
||||
prompt_len + output_len
|
||||
for _, prompt_len, output_len in requests
|
||||
)
|
||||
total_num_tokens = sum(prompt_len + output_len
|
||||
for _, prompt_len, output_len in requests)
|
||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
||||
parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
|
||||
parser.add_argument("--backend",
|
||||
type=str,
|
||||
choices=["vllm", "hf", "mii"],
|
||||
default="vllm")
|
||||
parser.add_argument("--dataset", type=str, required=True,
|
||||
parser.add_argument("--dataset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset.")
|
||||
parser.add_argument("--input-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Input prompt length for each request")
|
||||
parser.add_argument("--output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. Overrides the "
|
||||
"output length from the dataset.")
|
||||
parser.add_argument("--model", type=str, default="facebook/opt-125m")
|
||||
parser.add_argument("--tokenizer", type=str, default=None)
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
choices=['awq', 'squeezellm', None],
|
||||
default=None)
|
||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
||||
parser.add_argument("--n", type=int, default=1,
|
||||
parser.add_argument("--n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of generated sequences per prompt.")
|
||||
parser.add_argument("--use-beam-search", action="store_true")
|
||||
parser.add_argument("--num-prompts", type=int, default=1000,
|
||||
parser.add_argument("--num-prompts",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of prompts to process.")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--hf-max-batch-size", type=int, default=None,
|
||||
parser.add_argument("--hf-max-batch-size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum batch size for HF backend.")
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
help='trust remote code from huggingface')
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
default='auto',
|
||||
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
|
||||
help='data type for model weights and activations. '
|
||||
'The "auto" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.')
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
if args.dataset is None:
|
||||
assert args.input_len is not None
|
||||
assert args.output_len is not None
|
||||
else:
|
||||
assert args.input_len is None
|
||||
|
||||
if args.backend == "vllm":
|
||||
if args.hf_max_batch_size is not None:
|
||||
@ -215,7 +286,20 @@ if __name__ == "__main__":
|
||||
elif args.backend == "hf":
|
||||
if args.hf_max_batch_size is None:
|
||||
raise ValueError("HF max batch size is required for HF backend.")
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
|
||||
if args.quantization is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
elif args.backend == "mii":
|
||||
if args.dtype != "auto":
|
||||
raise ValueError("dtype must be auto for MII backend.")
|
||||
if args.n != 1:
|
||||
raise ValueError("n must be 1 for MII backend.")
|
||||
if args.use_beam_search:
|
||||
raise ValueError("Beam search is not supported for MII backend.")
|
||||
if args.quantization is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
if args.hf_max_batch_size is not None:
|
||||
raise ValueError("HF max batch size is only for HF backend.")
|
||||
if args.tokenizer != args.model:
|
||||
raise ValueError("Tokenizer must be the same as the model for MII "
|
||||
"backend.")
|
||||
main(args)
|
||||
|
benchmarks/kernels/benchmark_paged_attention.py (197 changes, new file)
@@ -0,0 +1,197 @@
|
||||
import argparse
|
||||
import random
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import attention_ops
|
||||
|
||||
NUM_BLOCKS = 1024
|
||||
PARTITION_SIZE = 512
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(
|
||||
version: str,
|
||||
num_seqs: int,
|
||||
context_len: int,
|
||||
num_query_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
use_alibi: bool,
|
||||
block_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
do_profile: bool,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
query = torch.empty(num_seqs,
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
query.uniform_(-scale, scale)
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
head_mapping = torch.repeat_interleave(
|
||||
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
|
||||
num_queries_per_kv)
|
||||
alibi_slopes = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads,
|
||||
dtype=torch.float,
|
||||
device="cuda")
|
||||
|
||||
context_lens = [context_len for _ in range(num_seqs)]
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||
block_tables = []
|
||||
for _ in range(num_seqs):
|
||||
block_table = [
|
||||
random.randint(0, NUM_BLOCKS - 1)
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
|
||||
|
||||
# Create the KV cache.
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
|
||||
key_cache.uniform_(-scale, scale)
|
||||
value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
|
||||
value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
value_cache.uniform_(-scale, scale)
|
||||
|
||||
# Prepare for the paged attention kernel.
|
||||
output = torch.empty_like(query)
|
||||
if version == "v2":
|
||||
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
|
||||
PARTITION_SIZE)
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_query_heads, num_partitions, head_size),
|
||||
dtype=output.dtype,
|
||||
device=output.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_query_heads, num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=output.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
|
||||
def run_benchmark(num_iters: int, profile: bool = False) -> float:
|
||||
torch.cuda.synchronize()
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
start_time = time.perf_counter()
|
||||
|
||||
for _ in range(num_iters):
|
||||
if version == "v1":
|
||||
attention_ops.paged_attention_v1(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
head_mapping,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
block_size,
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
)
|
||||
elif version == "v2":
|
||||
attention_ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
head_mapping,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
block_size,
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid version: {version}")
|
||||
torch.cuda.synchronize()
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
return (end_time - start_time) / num_iters
|
||||
|
||||
# Warmup.
|
||||
print("Warming up...")
|
||||
run_benchmark(num_iters=3, profile=False)
|
||||
|
||||
# Benchmark.
|
||||
if do_profile:
|
||||
latency = run_benchmark(num_iters=1, profile=True)
|
||||
else:
|
||||
latency = run_benchmark(num_iters=100, profile=False)
|
||||
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark the paged attention kernel.")
|
||||
parser.add_argument("--version",
|
||||
type=str,
|
||||
choices=["v1", "v2"],
|
||||
default="v2")
|
||||
parser.add_argument("--batch-size", type=int, default=8)
|
||||
parser.add_argument("--context-len", type=int, default=4096)
|
||||
parser.add_argument("--num-query-heads", type=int, default=64)
|
||||
parser.add_argument("--num-kv-heads", type=int, default=8)
|
||||
parser.add_argument("--head-size",
|
||||
type=int,
|
||||
choices=[64, 80, 96, 112, 128, 256],
|
||||
default=128)
|
||||
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
|
||||
parser.add_argument("--use-alibi", action="store_true")
|
||||
parser.add_argument("--dtype",
|
||||
type=str,
|
||||
choices=["half", "bfloat16", "float"],
|
||||
default="half")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--profile", action="store_true")
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
if args.num_query_heads % args.num_kv_heads != 0:
|
||||
raise ValueError("num_query_heads must be divisible by num_kv_heads")
|
||||
dtype_to_torch_dtype = {
|
||||
"half": torch.half,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float": torch.float,
|
||||
}
|
||||
main(
|
||||
version=args.version,
|
||||
num_seqs=args.batch_size,
|
||||
context_len=args.context_len,
|
||||
num_query_heads=args.num_query_heads,
|
||||
num_kv_heads=args.num_kv_heads,
|
||||
head_size=args.head_size,
|
||||
block_size=args.block_size,
|
||||
use_alibi=args.use_alibi,
|
||||
dtype=dtype_to_torch_dtype[args.dtype],
|
||||
seed=args.seed,
|
||||
do_profile=args.profile,
|
||||
)
|
@@ -13,11 +13,11 @@ __device__ __forceinline__ T silu(const T& x) {

template<typename scalar_t>
__global__ void silu_and_mul_kernel(
scalar_t* __restrict__ out, // [num_tokens, d]
const scalar_t* __restrict__ input, // [num_tokens, 2, d]
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = silu(x) * y;
@@ -27,11 +27,11 @@ __global__ void silu_and_mul_kernel(
} // namespace vllm

void silu_and_mul(
torch::Tensor& out, // [num_tokens, d]
torch::Tensor& input) // [num_tokens, 2 * d]
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
int num_tokens = input.size(0);
int d = input.size(1) / 2;
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;

dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
@@ -52,11 +52,11 @@ namespace vllm {
// Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
scalar_t* __restrict__ out, // [num_tokens, d]
const scalar_t* __restrict__ input, // [num_tokens, d]
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d]
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
}
@@ -66,8 +66,8 @@ __global__ void activation_kernel(

// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int num_tokens = input.size(0); \
int d = input.size(1); \
int d = input.size(-1); \
int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
@@ -100,15 +100,15 @@ __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
} // namespace vllm

void gelu_new(
torch::Tensor& out, // [num_tokens, d]
torch::Tensor& input) // [num_tokens, d]
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}

void gelu_fast(
torch::Tensor& out, // [num_tokens, d]
torch::Tensor& input) // [num_tokens, d]
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
@@ -1,7 +1,7 @@
#include <torch/extension.h>
#include <c10/util/Optional.h>

void single_query_cached_kv_attention(
void paged_attention_v1(
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
@@ -14,9 +14,29 @@ void single_query_cached_kv_attention(
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes);

void paged_attention_v2(
torch::Tensor& out,
torch::Tensor& exp_sums,
torch::Tensor& max_logits,
torch::Tensor& tmp_out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& head_mapping,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"single_query_cached_kv_attention",
&single_query_cached_kv_attention,
"Compute the attention between an input query and the cached key/value tensors");
"paged_attention_v1",
&paged_attention_v1,
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
m.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
}
@ -26,6 +26,7 @@
|
||||
#define WARP_SIZE 32
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
namespace vllm {
|
||||
|
||||
@ -65,14 +66,18 @@ inline __device__ float block_sum(float* red_smem, float sum) {
|
||||
return __shfl_sync(uint32_t(-1), sum, 0);
|
||||
}
|
||||
|
||||
// Grid: (num_heads, num_seqs).
|
||||
// TODO(woosuk): Merge the last two dimensions of the grid.
|
||||
// Grid: (num_heads, num_seqs, max_num_partitions).
|
||||
template<
|
||||
typename scalar_t,
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS>
|
||||
__global__ void single_query_cached_kv_attention_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
int NUM_THREADS,
|
||||
int PARTITION_SIZE = 0> // Zero means no partitioning.
|
||||
__device__ void paged_attention_kernel(
|
||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
@ -85,10 +90,33 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride) {
|
||||
const int seq_idx = blockIdx.y;
|
||||
const int partition_idx = blockIdx.z;
|
||||
const int max_num_partitions = gridDim.z;
|
||||
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
|
||||
const int context_len = context_lens[seq_idx];
|
||||
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
|
||||
// No work to do. Terminate the thread block.
|
||||
return;
|
||||
}
|
||||
|
||||
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
|
||||
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
|
||||
|
||||
// [start_block_idx, end_block_idx) is the range of blocks to process.
|
||||
const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
|
||||
const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
|
||||
const int num_blocks = end_block_idx - start_block_idx;
|
||||
|
||||
// [start_token_idx, end_token_idx) is the range of tokens to process.
|
||||
const int start_token_idx = start_block_idx * BLOCK_SIZE;
|
||||
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
|
||||
const int num_tokens = end_token_idx - start_token_idx;
|
||||
|
||||
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
|
||||
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
|
||||
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
|
||||
constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int thread_idx = threadIdx.x;
|
||||
const int warp_idx = thread_idx / WARP_SIZE;
|
||||
@ -97,7 +125,6 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
const int head_idx = blockIdx.x;
|
||||
const int num_heads = gridDim.x;
|
||||
const int kv_head_idx = head_mapping[head_idx];
|
||||
const int seq_idx = blockIdx.y;
|
||||
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
|
||||
|
||||
// A vector type to store a part of a key or a query.
|
||||
@ -142,16 +169,16 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
constexpr int x = 16 / sizeof(scalar_t);
|
||||
float qk_max = -FLT_MAX;
|
||||
|
||||
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||
const int context_len = context_lens[seq_idx];
|
||||
const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
// Iterate over the key blocks.
|
||||
// Each warp fetches a block of keys for each iteration.
|
||||
// Each thread group in a warp fetches a key from the block, and computes
|
||||
// dot product with the query.
|
||||
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
|
||||
const int physical_block_number = block_table[block_idx];
|
||||
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
|
||||
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
|
||||
// because int32 can lead to overflow when this variable is multiplied by large numbers
|
||||
// (e.g., kv_block_stride).
|
||||
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
|
||||
|
||||
// Load a key to registers.
|
||||
// Each thread in a thread group has a different part of the key.
|
||||
@ -184,7 +211,7 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
// Store the partial reductions to shared memory.
|
||||
// NOTE(woosuk): It is required to zero out the masked logits.
|
||||
const bool mask = token_idx >= context_len;
|
||||
logits[token_idx] = mask ? 0.f : qk;
|
||||
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
|
||||
// Update the max value.
|
||||
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
|
||||
}
|
||||
@ -215,7 +242,7 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
|
||||
// Get the sum of the exp values.
|
||||
float exp_sum = 0.f;
|
||||
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
|
||||
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
|
||||
float val = __expf(logits[i] - qk_max);
|
||||
logits[i] = val;
|
||||
exp_sum += val;
|
||||
@ -224,11 +251,23 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
|
||||
// Compute softmax.
|
||||
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
|
||||
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
|
||||
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
|
||||
logits[i] *= inv_sum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// If partitioning is enabled, store the max logit and exp_sum.
|
||||
if (USE_PARTITIONING && thread_idx == 0) {
|
||||
float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
|
||||
+ head_idx * max_num_partitions
|
||||
+ partition_idx;
|
||||
*max_logits_ptr = qk_max;
|
||||
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
|
||||
+ head_idx * max_num_partitions
|
||||
+ partition_idx;
|
||||
*exp_sums_ptr = exp_sum;
|
||||
}
|
||||
|
||||
// Each thread will fetch 16 bytes from the value cache at a time.
|
||||
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
|
||||
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||
@ -237,7 +276,7 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
|
||||
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
|
||||
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
|
||||
constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
|
||||
constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
|
||||
|
||||
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
|
||||
float accs[NUM_ROWS_PER_THREAD];
|
||||
@ -248,12 +287,15 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
|
||||
scalar_t zero_value;
|
||||
zero(zero_value);
|
||||
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
|
||||
const int physical_block_number = block_table[block_idx];
|
||||
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
|
||||
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
|
||||
// because int32 can lead to overflow when this variable is multiplied by large numbers
|
||||
// (e.g., kv_block_stride).
|
||||
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
|
||||
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
|
||||
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
|
||||
L_vec logits_vec;
|
||||
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
|
||||
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
|
||||
|
||||
const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
|
||||
+ kv_head_idx * kv_head_stride;
|
||||
@ -263,13 +305,13 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
if (row_idx < HEAD_SIZE) {
|
||||
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
||||
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
||||
if (block_idx == num_blocks - 1) {
|
||||
if (block_idx == num_context_blocks - 1) {
|
||||
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
|
||||
// we should explicitly zero out the values since they may contain NaNs.
|
||||
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
|
||||
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
|
||||
#pragma unroll
|
||||
for (int j = 0; j <= V_VEC_SIZE; j++) {
|
||||
for (int j = 0; j < V_VEC_SIZE; j++) {
|
||||
v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
|
||||
}
|
||||
}
|
||||
@ -327,7 +369,9 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
|
||||
// Write the final output.
|
||||
if (warp_idx == 0) {
|
||||
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
||||
+ head_idx * max_num_partitions * HEAD_SIZE
|
||||
+ partition_idx * HEAD_SIZE;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||
@ -338,10 +382,167 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
// Grid: (num_heads, num_seqs, 1).
template<
  typename scalar_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS>
__global__ void paged_attention_v1_kernel(
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
  const int* __restrict__ head_mapping,   // [num_heads]
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
    /* exp_sums */ nullptr, /* max_logits */ nullptr,
    out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens,
    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
template<
  typename scalar_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
  scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
  const int* __restrict__ head_mapping,   // [num_heads]
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale,
    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
    q_stride, kv_block_stride, kv_head_stride);
}

// Grid: (num_heads, num_seqs).
template<
  typename scalar_t,
  int HEAD_SIZE,
  int NUM_THREADS,
  int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
  const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
  const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
  const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_partitions) {
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int context_len = context_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
  if (num_partitions == 1) {
    // No need to reduce. Only copy tmp_out to out.
    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                                          + head_idx * max_num_partitions * HEAD_SIZE;
    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
      out_ptr[i] = tmp_out_ptr[i];
    }
    // Terminate the thread block.
    return;
  }

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warp_idx = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;

  // Size: 2 * num_partitions.
  extern __shared__ char shared_mem[];
  // Workspace for reduction.
  __shared__ float red_smem[2 * NUM_WARPS];

  // Load max logits to shared memory.
  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
                                           + head_idx * max_num_partitions;
  float max_logit = -FLT_MAX;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    const float l = max_logits_ptr[i];
    shared_max_logits[i] = l;
    max_logit = fmaxf(max_logit, l);
  }
  __syncthreads();

  // Get the global max logit.
  // Reduce within the warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;
  }
  __syncthreads();
  // Reduce across warps.
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
  }
  // Broadcast the max value to all threads.
  max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);

  // Load rescaled exp sums to shared memory.
  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
                                       + head_idx * max_num_partitions;
  float global_exp_sum = 0.0f;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    float l = shared_max_logits[i];
    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
    global_exp_sum += rescaled_exp_sum;
    shared_exp_sums[i] = rescaled_exp_sum;
  }
  __syncthreads();
  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

  // Aggregate tmp_out to out.
  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                                        + head_idx * max_num_partitions * HEAD_SIZE;
  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    float acc = 0.0f;
    for (int j = 0; j < num_partitions; ++j) {
      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
    }
    from_float(out_ptr[i], acc);
  }
}

} // namespace vllm

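The reduce kernel above merges the per-partition softmax statistics with the usual log-sum-exp trick: every partition's exp-sum is rescaled by exp(m_i - m_global), and the partial outputs are combined as an exp-sum-weighted average. A minimal host-side C++ sketch of that merge for one (sequence, head) pair, with illustrative names that are not part of the diff:

#include <cmath>
#include <vector>

// Merge per-partition attention results; tmp_out is laid out as
// [num_partitions][head_size], matching the kernel's view of one (seq, head).
std::vector<float> merge_partitions(const std::vector<float>& max_logits,  // [num_partitions]
                                    const std::vector<float>& exp_sums,    // [num_partitions]
                                    const std::vector<float>& tmp_out,     // [num_partitions * head_size]
                                    int head_size) {
  const int num_partitions = static_cast<int>(max_logits.size());
  // Global max logit across partitions.
  float max_logit = -INFINITY;
  for (float m : max_logits) max_logit = std::max(max_logit, m);
  // Rescale each partition's exp-sum to the global max.
  std::vector<float> rescaled(num_partitions);
  float global_exp_sum = 0.0f;
  for (int j = 0; j < num_partitions; ++j) {
    rescaled[j] = exp_sums[j] * std::exp(max_logits[j] - max_logit);
    global_exp_sum += rescaled[j];
  }
  const float inv_sum = 1.0f / (global_exp_sum + 1e-6f);
  // Exp-sum-weighted average of the partial outputs.
  std::vector<float> out(head_size, 0.0f);
  for (int j = 0; j < num_partitions; ++j) {
    for (int i = 0; i < head_size; ++i) {
      out[i] += tmp_out[j * head_size + i] * rescaled[j] * inv_sum;
    }
  }
  return out;
}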
#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
  vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
  cudaFuncSetAttribute( \
    vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
    cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
  vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
  <<<grid, block, shared_mem_size, stream>>>( \
    out_ptr, \
    query_ptr, \
@@ -362,7 +563,7 @@ template<
  typename T,
  int BLOCK_SIZE,
  int NUM_THREADS = 128>
void single_query_cached_kv_attention_launcher(
void paged_attention_v1_launcher(
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
@@ -398,43 +599,37 @@ void single_query_cached_kv_attention_launcher(
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_context_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
  // Keep that in sync with the logic here!
  int shared_mem_size = std::max(logits_size, outputs_size);

  dim3 grid(num_heads, num_seqs);
  dim3 grid(num_heads, num_seqs, 1);
  dim3 block(NUM_THREADS);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we omitted head sizes
    // 32, 160, 192.
    // case 32:
    //   LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
    //   break;
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
      LAUNCH_PAGED_ATTENTION_V1(64);
      break;
    case 80:
      LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
      LAUNCH_PAGED_ATTENTION_V1(80);
      break;
    case 96:
      LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
      LAUNCH_PAGED_ATTENTION_V1(96);
      break;
    case 112:
      LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS);
      LAUNCH_PAGED_ATTENTION_V1(112);
      break;
    case 128:
      LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
      LAUNCH_PAGED_ATTENTION_V1(128);
      break;
    // case 160:
    //   LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
    //   break;
    // case 192:
    //   LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
    //   break;
    case 256:
      LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
      LAUNCH_PAGED_ATTENTION_V1(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
@@ -442,8 +637,8 @@ void single_query_cached_kv_attention_launcher(
  }
}

#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
  single_query_cached_kv_attention_launcher<T, BLOCK_SIZE>( \
#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \
  paged_attention_v1_launcher<T, BLOCK_SIZE>( \
    out, \
    query, \
    key_cache, \
@@ -457,41 +652,23 @@ void single_query_cached_kv_attention_launcher(

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \
  switch (block_size) { \
    /* case 1: */ \
    /*   CALL_KERNEL_LAUNCHER(T, 1); */ \
    /*   break; */ \
    /* case 2: */ \
    /*   CALL_KERNEL_LAUNCHER(T, 2); */ \
    /*   break; */ \
    /* case 4: */ \
    /*   CALL_KERNEL_LAUNCHER(T, 4); */ \
    /*   break; */ \
    case 8: \
      CALL_KERNEL_LAUNCHER(T, 8); \
      CALL_V1_LAUNCHER(T, 8); \
      break; \
    case 16: \
      CALL_KERNEL_LAUNCHER(T, 16); \
      CALL_V1_LAUNCHER(T, 16); \
      break; \
    case 32: \
      CALL_KERNEL_LAUNCHER(T, 32); \
      CALL_V1_LAUNCHER(T, 32); \
      break; \
    /* case 64: */ \
    /*   CALL_KERNEL_LAUNCHER(T, 64); */ \
    /*   break; */ \
    /* case 128: */ \
    /*   CALL_KERNEL_LAUNCHER(T, 128); */ \
    /*   break; */ \
    /* case 256: */ \
    /*   CALL_KERNEL_LAUNCHER(T, 256); */ \
    /*   break; */ \
    default: \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break; \
  }

void single_query_cached_kv_attention(
void paged_attention_v1(
  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
@@ -504,11 +681,186 @@ void single_query_cached_kv_attention(
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  if (query.dtype() == at::ScalarType::Float) {
    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
    CALL_V1_LAUNCHER_BLOCK_SIZE(float);
  } else if (query.dtype() == at::ScalarType::Half) {
    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
    CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t);
  } else if (query.dtype() == at::ScalarType::BFloat16) {
    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
    CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  }
}

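The dynamic shared memory request in the v1 launcher is the larger of the padded logits buffer and the per-warp output buffer, which is why the Python-side max-seq-len check has to stay in sync with it. A rough worked example under assumed sizes (illustrative values, not taken from the diff):

#include <algorithm>
#include <cstdio>

int main() {
  // Assumed example values; the real launcher reads them from the tensors.
  constexpr int WARP_SIZE = 32;
  constexpr int NUM_THREADS = 128;
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;  // 4
  constexpr int BLOCK_SIZE = 16;
  constexpr int head_size = 128;
  constexpr int max_context_len = 4096;

  // Same arithmetic as paged_attention_v1_launcher.
  int padded_max_context_len = (max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;  // 4096
  int logits_size = padded_max_context_len * sizeof(float);        // 16384 bytes
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);  // 1024 bytes
  int shared_mem_size = std::max(logits_size, outputs_size);       // 16384 bytes
  std::printf("shared_mem_size = %d bytes\n", shared_mem_size);
  return 0;
}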
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
  vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE> \
  <<<grid, block, shared_mem_size, stream>>>( \
    exp_sums_ptr, \
    max_logits_ptr, \
    tmp_out_ptr, \
    query_ptr, \
    key_cache_ptr, \
    value_cache_ptr, \
    head_mapping_ptr, \
    scale, \
    block_tables_ptr, \
    context_lens_ptr, \
    max_num_blocks_per_seq, \
    alibi_slopes_ptr, \
    q_stride, \
    kv_block_stride, \
    kv_head_stride); \
  vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
  <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
    out_ptr, \
    exp_sums_ptr, \
    max_logits_ptr, \
    tmp_out_ptr, \
    context_lens_ptr, \
    max_num_partitions);

template<
  typename T,
  int BLOCK_SIZE,
  int NUM_THREADS = 128,
  int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
  torch::Tensor& out,
  torch::Tensor& exp_sums,
  torch::Tensor& max_logits,
  torch::Tensor& tmp_out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& head_mapping,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  assert(head_size % thread_group_size == 0);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr = alibi_slopes ?
    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
    : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
  int logits_size = PARTITION_SIZE * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

  // For paged attention v2 kernel.
  dim3 grid(num_heads, num_seqs, max_num_partitions);
  int shared_mem_size = std::max(logits_size, outputs_size);
  // For paged attention v2 reduce kernel.
  dim3 reduce_grid(num_heads, num_seqs);
  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

  dim3 block(NUM_THREADS);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V2(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V2(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V2(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V2(112);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V2(128);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V2(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \
  paged_attention_v2_launcher<T, BLOCK_SIZE>( \
    out, \
    exp_sums, \
    max_logits, \
    tmp_out, \
    query, \
    key_cache, \
    value_cache, \
    head_mapping, \
    scale, \
    block_tables, \
    context_lens, \
    max_context_len, \
    alibi_slopes);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \
  switch (block_size) { \
    case 8: \
      CALL_V2_LAUNCHER(T, 8); \
      break; \
    case 16: \
      CALL_V2_LAUNCHER(T, 16); \
      break; \
    case 32: \
      CALL_V2_LAUNCHER(T, 32); \
      break; \
    default: \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break; \
  }

void paged_attention_v2(
  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
  torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
  torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
  torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& head_mapping,    // [num_heads]
  float scale,
  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens,    // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  if (query.dtype() == at::ScalarType::Float) {
    CALL_V2_LAUNCHER_BLOCK_SIZE(float);
  } else if (query.dtype() == at::ScalarType::Half) {
    CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t);
  } else if (query.dtype() == at::ScalarType::BFloat16) {
    CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  }
}
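The v2 path splits each sequence's context into PARTITION_SIZE-token partitions and assigns each partition its own block along gridDim.z; the reduce kernel then folds the partitions back together. A small sketch of the resulting launch geometry under assumed sizes (illustrative values only):

#include <cstdio>

int main() {
  // Assumed example values.
  constexpr int PARTITION_SIZE = 512;
  constexpr int num_heads = 32;
  constexpr int num_seqs = 8;
  constexpr int max_context_len = 2000;

  // Same arithmetic as paged_attention_v2_launcher.
  int max_num_partitions = (max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;  // 4
  std::printf("grid = (%d, %d, %d)\n", num_heads, num_seqs, max_num_partitions);
  std::printf("reduce_grid = (%d, %d)\n", num_heads, num_seqs);
  // Reduce kernel's shared memory: max logits plus rescaled exp sums, one float each per partition.
  int reduce_shared_mem_size = 2 * max_num_partitions * (int)sizeof(float);           // 32 bytes
  std::printf("reduce_shared_mem_size = %d bytes\n", reduce_shared_mem_size);
  return 0;
}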
@@ -517,3 +869,4 @@ void single_query_cached_kv_attention(

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
@@ -420,6 +420,11 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#endif
}

// From bfloat16 to float32.
inline __device__ float to_float(__nv_bfloat16 u) {
  return __bfloat162float(u);
}

// Zero-out a variable.
inline __device__ void zero(__nv_bfloat16& dst) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
@@ -55,26 +55,26 @@ template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int* __restrict__ block_mapping,
  const int64_t* __restrict__ block_mapping,
  const int numel_per_block) {
  const int layer_idx = blockIdx.x;
  const int pair_idx = blockIdx.y;

  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
  int src_block_number = block_mapping[2 * pair_idx];
  int dst_block_number = block_mapping[2 * pair_idx + 1];
  int64_t src_block_number = block_mapping[2 * pair_idx];
  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

  const int src_block_offset = src_block_number * numel_per_block;
  const int dst_block_offset = dst_block_number * numel_per_block;
  const int64_t src_block_offset = src_block_number * numel_per_block;
  const int64_t dst_block_offset = dst_block_number * numel_per_block;
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int src_offset = src_block_offset + i;
    int dst_offset = dst_block_offset + i;
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    key_cache[dst_offset] = key_cache[src_offset];
  }
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int src_offset = src_block_offset + i;
    int dst_offset = dst_block_offset + i;
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    value_cache[dst_offset] = value_cache[src_offset];
  }
}
@@ -102,15 +102,15 @@ void copy_blocks(
    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }
  // Create block mapping array.
  std::vector<int> block_mapping_vec;
  std::vector<int64_t> block_mapping_vec;
  for (const auto& pair : block_mapping) {
    int src_block_number = pair.first;
    for (int dst_block_number : pair.second) {
    int64_t src_block_number = pair.first;
    for (int64_t dst_block_number : pair.second) {
      block_mapping_vec.push_back(src_block_number);
      block_mapping_vec.push_back(dst_block_number);
    }
  }
  int* block_mapping_array = block_mapping_vec.data();
  int64_t* block_mapping_array = block_mapping_vec.data();
  int num_pairs = block_mapping_vec.size() / 2;

  // Move the data structures to the GPU.
@@ -120,7 +120,7 @@ void copy_blocks(
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor block_mapping_tensor = torch::from_blob(
    block_mapping_array, {2 * num_pairs}, torch::kInt).to(cache_device);
    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);

  // Launch the kernel.
  const int numel_per_block = key_caches[0][0].numel();
@@ -132,7 +132,7 @@ void copy_blocks(
    vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
      key_cache_ptrs_tensor.data_ptr<int64_t>(),
      value_cache_ptrs_tensor.data_ptr<int64_t>(),
      block_mapping_tensor.data_ptr<int>(),
      block_mapping_tensor.data_ptr<int64_t>(),
      numel_per_block);
  }));
}
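The int to int64_t switch in copy_blocks_kernel matters because the element offset is block_number * numel_per_block, which can exceed the 2^31 - 1 range of a 32-bit int on a large KV cache. A tiny sketch with assumed cache dimensions (illustrative only):

#include <cstdint>
#include <cstdio>

int main() {
  // Assumed per-block element count: num_heads * head_size * block_size.
  const int64_t numel_per_block = 40LL * 128 * 16;         // 81920 elements
  const int64_t block_number = 30000;                      // plausible on a large cache
  const int64_t offset = block_number * numel_per_block;   // 2'457'600'000
  // offset exceeds INT32_MAX, so computing it in int would overflow.
  std::printf("offset = %lld (INT32_MAX = %d)\n",
              static_cast<long long>(offset), INT32_MAX);
  return 0;
}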
@@ -141,43 +141,48 @@ namespace vllm {

template<typename scalar_t>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,       // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,     // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
  scalar_t* __restrict__ value_cache,     // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping,   // [num_tokens]
  const scalar_t* __restrict__ key,         // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,       // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key_cache,         // [num_blocks, num_heads, head_size/x, block_size, x]
  scalar_t* __restrict__ value_cache,       // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping, // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int token_idx = blockIdx.x;
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int src_key_idx = token_idx * key_stride + i;
    const int src_value_idx = token_idx * value_stride + i;
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                          + head_idx * (head_size / x) * block_size * x
                          + x_idx * block_size * x
                          + block_offset * x
                          + x_offset;
    const int tgt_value_idx = block_idx * num_heads * head_size * block_size
                            + head_idx * head_size * block_size
                            + head_offset * block_size
                            + block_offset;
    key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
    value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                              + head_idx * (head_size / x) * block_size * x
                              + x_idx * block_size * x
                              + block_offset * x
                              + x_offset;
    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
                                + head_idx * head_size * block_size
                                + head_offset * block_size
                                + block_offset;
    key_cache[tgt_key_idx] = key[src_key_idx];
    value_cache[tgt_value_idx] = value[src_value_idx];
  }
}

@@ -211,7 +216,7 @@ void reshape_and_cache(
      value.data_ptr<scalar_t>(),
      key_cache.data_ptr<scalar_t>(),
      value_cache.data_ptr<scalar_t>(),
      slot_mapping.data_ptr<int>(),
      slot_mapping.data_ptr<int64_t>(),
      key_stride,
      value_stride,
      num_heads,
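For reference, the slot-to-cache-index mapping that reshape_and_cache_kernel applies can be written as plain host code; a minimal sketch with assumed dimensions (illustrative, not part of the diff):

#include <cstdint>
#include <cstdio>

// Index of one key element inside the paged key cache,
// whose layout is [num_blocks, num_heads, head_size/x, block_size, x].
int64_t key_cache_index(int64_t slot_idx, int num_heads, int head_size,
                        int block_size, int x, int head_idx, int head_offset) {
  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;
  const int x_idx = head_offset / x;
  const int x_offset = head_offset % x;
  return block_idx * num_heads * (head_size / x) * block_size * x
       + head_idx * (head_size / x) * block_size * x
       + x_idx * block_size * x
       + block_offset * x
       + x_offset;
}

int main() {
  // Assumed example: block_size 16, head_size 128, x = 16 bytes / sizeof(half) = 8.
  const int64_t idx = key_cache_index(/*slot_idx=*/37, /*num_heads=*/32, /*head_size=*/128,
                                      /*block_size=*/16, /*x=*/8, /*head_idx=*/3, /*head_offset=*/21);
  std::printf("target key index = %lld\n", static_cast<long long>(idx));
  return 0;
}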
csrc/cuda_utils.cpp (new file, 13 lines)
@@ -0,0 +1,13 @@
#include <torch/extension.h>

int get_device_attribute(
  int attribute,
  int device_id);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "get_device_attribute",
    &get_device_attribute,
    "Gets the specified device attribute.");
}
csrc/cuda_utils_kernels.cu (new file, 14 lines)
@@ -0,0 +1,14 @@
int get_device_attribute(
  int attribute,
  int device_id)
{
  int device, value;
  if (device_id < 0) {
    cudaGetDevice(&device);
  }
  else {
    device = device_id;
  }
  cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
  return value;
}
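A hedged usage sketch for the new helper: from C++ it can be called with any cudaDeviceAttr enumerator, for example to read the opt-in shared-memory limit that the attention launchers depend on. The attribute chosen here is an assumption for illustration, not something the diff states:

#include <cuda_runtime.h>
#include <cstdio>

int get_device_attribute(int attribute, int device_id);  // from csrc/cuda_utils_kernels.cu

int main() {
  // cudaDevAttrMaxSharedMemoryPerBlockOptin: maximum dynamic shared memory a block may opt in to.
  int max_shared = get_device_attribute(cudaDevAttrMaxSharedMemoryPerBlockOptin, /*device_id=*/0);
  std::printf("max opt-in shared memory per block: %d bytes\n", max_shared);
  return 0;
}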
@@ -6,9 +6,19 @@ void rms_norm(
  torch::Tensor& weight,
  float epsilon);

void fused_add_rms_norm(
  torch::Tensor& input,
  torch::Tensor& residual,
  torch::Tensor& weight,
  float epsilon);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "rms_norm",
    &rms_norm,
    "Apply Root Mean Square (RMS) Normalization to the input tensor.");
  m.def(
    "fused_add_rms_norm",
    &fused_add_rms_norm,
    "In-place fused Add and RMS Normalization");
}
@@ -9,8 +9,8 @@ namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template<typename scalar_t>
__global__ void rms_norm_kernel(
  scalar_t* __restrict__ out,           // [num_tokens, hidden_size]
  const scalar_t* __restrict__ input,   // [num_tokens, hidden_size]
  scalar_t* __restrict__ out,           // [..., hidden_size]
  const scalar_t* __restrict__ input,   // [..., hidden_size]
  const scalar_t* __restrict__ weight,  // [hidden_size]
  const float epsilon,
  const int num_tokens,
@@ -34,15 +34,45 @@ __global__ void rms_norm_kernel(
  }
}

// TODO: Further optimize this kernel.
template<typename scalar_t>
__global__ void fused_add_rms_norm_kernel(
  scalar_t* __restrict__ input,         // [..., hidden_size]
  scalar_t* __restrict__ residual,      // [..., hidden_size]
  const scalar_t* __restrict__ weight,  // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    x += (float) residual[blockIdx.x * hidden_size + idx];
    variance += x * x;
    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) residual[blockIdx.x * hidden_size + idx];
    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}

} // namespace vllm

void rms_norm(
  torch::Tensor& out,     // [num_tokens, hidden_size]
  torch::Tensor& input,   // [num_tokens, hidden_size]
  torch::Tensor& out,     // [..., hidden_size]
  torch::Tensor& input,   // [..., hidden_size]
  torch::Tensor& weight,  // [hidden_size]
  float epsilon) {
  int num_tokens = input.size(0);
  int hidden_size = input.size(1);
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
@@ -60,3 +90,28 @@ void rms_norm(
        hidden_size);
    });
}

void fused_add_rms_norm(
  torch::Tensor& input,     // [..., hidden_size]
  torch::Tensor& residual,  // [..., hidden_size]
  torch::Tensor& weight,    // [hidden_size]
  float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "fused_add_rms_norm_kernel",
    [&] {
      vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        input.data_ptr<scalar_t>(),
        residual.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        epsilon,
        num_tokens,
        hidden_size);
    });
}
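As a readability aid, the math performed by fused_add_rms_norm_kernel for one token can be written on the host as a minimal C++ sketch (assumed float inputs, names illustrative): the residual is updated to input + residual, and input is overwritten with the RMS-normalized, weight-scaled result.

#include <cmath>
#include <vector>

// In-place fused add + RMS norm for a single token vector.
void fused_add_rms_norm_ref(std::vector<float>& input,
                            std::vector<float>& residual,
                            const std::vector<float>& weight,
                            float epsilon) {
  const size_t hidden_size = input.size();
  float variance = 0.0f;
  for (size_t i = 0; i < hidden_size; ++i) {
    const float x = input[i] + residual[i];
    residual[i] = x;          // residual now holds the summed activation
    variance += x * x;
  }
  const float inv_rms = 1.0f / std::sqrt(variance / hidden_size + epsilon);
  for (size_t i = 0; i < hidden_size; ++i) {
    input[i] = residual[i] * inv_rms * weight[i];  // input now holds the normalized output
  }
}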
@@ -37,9 +37,9 @@ inline __device__ void apply_rotary_embedding(

template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
  const int64_t* __restrict__ positions,  // [num_tokens]
  scalar_t* __restrict__ query,           // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,             // [num_tokens, num_kv_heads, head_size]
  const int64_t* __restrict__ positions,  // [batch_size, seq_len] or [num_tokens]
  scalar_t* __restrict__ query,           // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,             // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int query_stride,
@@ -78,18 +78,18 @@ __global__ void rotary_embedding_kernel(
} // namespace vllm

void rotary_embedding(
  torch::Tensor& positions,       // [num_tokens]
  torch::Tensor& query,           // [num_tokens, num_heads * head_size]
  torch::Tensor& key,             // [num_tokens, num_kv_heads * head_size]
  torch::Tensor& positions,       // [batch_size, seq_len] or [num_tokens]
  torch::Tensor& query,           // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
  torch::Tensor& key,             // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
  int head_size,
  torch::Tensor& cos_sin_cache,   // [max_position, rot_dim]
  bool is_neox) {
  int num_tokens = query.size(0);
  int64_t num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(1) / head_size;
  int num_kv_heads = key.size(1) / head_size;
  int query_stride = query.stride(0);
  int key_stride = key.stride(0);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
  int query_stride = query.stride(-2);
  int key_stride = key.stride(-2);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
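The rotary_embedding change lets the op accept either a flat [num_tokens, ...] layout or a batched [batch_size, seq_len, ...] layout by deriving num_tokens from numel and using the second-to-last stride. A small arithmetic sketch under assumed shapes (illustrative only):

#include <cstdint>
#include <cstdio>

int main() {
  // Assumed batched query shape: [batch_size, seq_len, num_heads * head_size].
  const int64_t batch_size = 2, seq_len = 5;
  const int64_t num_heads = 32, head_size = 128;
  const int64_t last_dim = num_heads * head_size;   // 4096
  const int64_t numel = batch_size * seq_len * last_dim;

  // Mirrors: num_tokens = query.numel() / query.size(-1);
  const int64_t num_tokens = numel / last_dim;      // 10
  // For a contiguous tensor, stride(-2) equals the size of the last dimension.
  const int64_t query_stride = last_dim;            // 4096
  std::printf("num_tokens = %lld, query_stride = %lld\n",
              static_cast<long long>(num_tokens), static_cast<long long>(query_stride));
  return 0;
}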
csrc/quantization.cpp (new file, 19 lines)
@@ -0,0 +1,19 @@
#include <torch/extension.h>

torch::Tensor awq_gemm(
  torch::Tensor _in_feats,
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters);

void squeezellm_gemm(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor lookup_table);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
}
csrc/quantization/awq/dequantize.cuh (new file, 87 lines)
@@ -0,0 +1,87 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/

#pragma once

namespace vllm {
namespace awq {

__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
  assert(false);
#else
  uint4 result;

  uint32_t* h = reinterpret_cast<uint32_t*>(&result);
  uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

  // First, we extract the i4s and construct an intermediate fp16 number.
  static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
  static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
  static constexpr uint32_t TOP_MASK = 0x00f000f0;
  static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

  // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
  // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
  // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
  // elt_67 to fp16 without having to shift them to the bottom bits before hand.

  // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
  // immediately before required.
  const uint32_t top_i4s = i4s >> 8;
  // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[0])
               : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
  // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[1])
               : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
  // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[2])
               : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
  // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[3])
               : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

  // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
  // half2 ctor. In this case, I chose performance reliability over code readability.

  // This is the half2 {1032, 1032} represented as an integer.
  // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
  // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
  static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
  // This is the half2 {1 / 16, 1 / 16} represented as an integer.
  static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
  // This is the half2 {-72, -72} represented as an integer.
  // static constexpr uint32_t NEG_72 = 0xd480d480;
  // Haotian: Let's use {-64, -64}.
  static constexpr uint32_t NEG_64 = 0xd400d400;

  // Finally, we construct the output numbers.
  // Convert elt_01
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_23
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
  // Convert elt_45
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_67
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

  return result;
#endif
}

} // namespace awq
} // namespace vllm
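Stripped of the PTX tricks above, each packed 32-bit word holds eight unsigned 4-bit weights, and the magic-number path turns every nibble into an fp16 value in [0, 15]; the GEMM kernel then applies (q - zero) * scale. A scalar C++ sketch of that numerical effect (the interleaved nibble order of the packed layout is glossed over here, so this is illustrative only):

#include <cstdint>
#include <cstdio>

// Value of the i-th 4-bit field of a packed word, as the dequantizer produces it (0..15).
inline float unpack_nibble(uint32_t packed, int i) {
  return static_cast<float>((packed >> (4 * i)) & 0xF);
}

// Effective AWQ dequantization applied by the GEMM kernel: (q - zero) * scale.
inline float dequantize(float q, float zero, float scale) {
  return (q - zero) * scale;
}

int main() {
  const uint32_t packed = 0x76543210u;  // assumed example word
  for (int i = 0; i < 8; ++i) {
    std::printf("%g ", dequantize(unpack_nibble(packed, i), /*zero=*/8.0f, /*scale=*/0.05f));
  }
  std::printf("\n");
  return 0;
}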
csrc/quantization/awq/gemm_kernels.cu (new file, 560 lines)
@@ -0,0 +1,560 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include "dequantize.cuh"

#include <cuda_fp16.h>

namespace vllm {
namespace awq {

// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}

__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
assert(false);
|
||||
#else
|
||||
static constexpr uint32_t ZERO = 0x0;
|
||||
float C_warp[32];
|
||||
__shared__ half A_shared[16 * (32 + 8)];
|
||||
__shared__ half B_shared[32 * (128 + 8)];
|
||||
|
||||
__shared__ half scaling_factors_shared[128];
|
||||
__shared__ half zeros_shared[128];
|
||||
|
||||
int j_factors1 = ((OC + 128 - 1) / 128);
|
||||
int blockIdx_x = 0;
|
||||
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||
|
||||
half A_shared_warp[8];
|
||||
half B_shared_warp[32];
|
||||
for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||
static constexpr int row_stride = 2 * 32 * 8 / 128;
|
||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||
|
||||
half* A_ptr = A
|
||||
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
||||
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||
|
||||
int* B_ptr = B
|
||||
+ ((int)threadIdx.y) * (OC / 8) * 2
|
||||
+ (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
|
||||
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
|
||||
+ (((int)threadIdx.x) % (128 / 8)) * 1;
|
||||
// Why * 1 in the above line?
|
||||
|
||||
half* A_shared_ptr = A_shared
|
||||
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
||||
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
||||
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||
|
||||
half* B_shared_ptr = B_shared
|
||||
+ ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
|
||||
+ (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
|
||||
+ (((int)threadIdx.x) % (128 / 8)) * 8;
|
||||
|
||||
int* zeros_ptr = zeros
|
||||
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
|
||||
+ ((int)threadIdx.x) % (128 / 8);
|
||||
|
||||
half* scaling_factors_ptr = scaling_factors
|
||||
+ (((int)blockIdx_y) % j_factors1) * (128)
|
||||
+ (((int)threadIdx.x) % (128 / 8)) * 8;
|
||||
|
||||
half* C_ptr = C
|
||||
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
||||
+ (((int)blockIdx_y) % j_factors1) * 128
|
||||
+ ((int)threadIdx.y) * 64
|
||||
+ (((int)threadIdx.x) % 4) * 2;
|
||||
|
||||
// preload s.f. and zeros
|
||||
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
|
||||
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
|
||||
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||
__syncthreads();
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
if (ld_A_flag)
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||
}
|
||||
else
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||
/*
|
||||
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
||||
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||
}
|
||||
*/
|
||||
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||
|
||||
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
|
||||
|
||||
// B: 32 x 136 (128+8) float16
|
||||
// each warp: 32 x 4
|
||||
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
||||
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
||||
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
|
||||
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
// - zero and * scale
|
||||
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
/*
|
||||
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
||||
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||
}
|
||||
*/
|
||||
|
||||
// write back
|
||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
|
||||
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
|
||||
for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
}
|
||||
for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
#else
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Shang: Hoist loop invariance.
|
||||
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
|
||||
for (int local_id = 0; local_id < 8; ++local_id) {
|
||||
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||
if (row_offset < M)
|
||||
{
|
||||
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
assert(false);
|
||||
#else
|
||||
static constexpr uint32_t ZERO = 0x0;
|
||||
float C_warp[32];
|
||||
__shared__ half A_shared[16 * (32 + 8)];
|
||||
__shared__ half B_shared[32 * (64 + 8)];
|
||||
|
||||
__shared__ half scaling_factors_shared[64];
|
||||
__shared__ half zeros_shared[64];
|
||||
|
||||
int j_factors1 = ((OC + 64 - 1) / 64);
|
||||
|
||||
int blockIdx_x = 0;
|
||||
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||
|
||||
half A_shared_warp[8];
|
||||
half B_shared_warp[16];
|
||||
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||
static constexpr int row_stride = 2 * 32 * 8 / 64;
|
||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||
|
||||
half* A_ptr = A
|
||||
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
||||
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||
|
||||
int* B_ptr = B
|
||||
+ ((int)threadIdx.y) * (OC / 8) * 4
|
||||
+ (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
|
||||
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
|
||||
+ (((int)threadIdx.x) % (64 / 8)) * 1;
|
||||
// Why * 1 in the above line?
|
||||
|
||||
half* A_shared_ptr = A_shared
|
||||
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
||||
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
||||
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||
|
||||
half* B_shared_ptr = B_shared
|
||||
+ ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
|
||||
+ (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
|
||||
+ (((int)threadIdx.x) % (64 / 8)) * 8;
|
||||
|
||||
int* zeros_ptr = zeros
|
||||
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
|
||||
+ ((int)threadIdx.x) % (64 / 8);
|
||||
|
||||
half* scaling_factors_ptr = scaling_factors
|
||||
+ (((int)blockIdx_y) % j_factors1) * (64)
|
||||
+ (((int)threadIdx.x) % (64 / 8)) * 8;
|
||||
|
||||
half* C_ptr = C
|
||||
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
||||
+ (((int)blockIdx_y) % j_factors1) * 64
|
||||
+ ((int)threadIdx.y) * 32
|
||||
+ (((int)threadIdx.x) % 4) * 2;
|
||||
|
||||
// preload s.f. and zeros
|
||||
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
|
||||
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
|
||||
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||
__syncthreads();
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
if (ld_A_flag)
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||
}
|
||||
else
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||
/*
|
||||
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
||||
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||
}
|
||||
*/
|
||||
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||
|
||||
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
|
||||
|
||||
// B: 32 x 136 (128+8) float16
|
||||
// each warp: 32 x 4
|
||||
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
||||
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
||||
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
|
||||
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
// - zero and * scale
|
||||
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
/*
|
||||
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
||||
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||
}
|
||||
*/
|
||||
|
||||
// write back
|
||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
|
||||
{
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
|
||||
{
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
#else
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Shang: Hoist the loop-invariant computations out of this loop.
|
||||
for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
|
||||
for (int local_id = 0; local_id < 8; ++local_id) {
|
||||
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||
if (row_offset < M)
|
||||
{
|
||||
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace awq
|
||||
} // namespace vllm
|
||||
|
||||
// in_feats: M, IC [float16]
|
||||
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
|
||||
// scaling_factors: IC // G, OC [float16]
|
||||
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
|
||||
// assume that batch_size < 16 for now
|
||||
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters)
|
||||
{
|
||||
int num_in_feats = _in_feats.size(0);
|
||||
int num_in_channels = _in_feats.size(1);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
|
||||
|
||||
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
|
||||
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
|
||||
int num_out_feats = _out_feats.size(-2);
|
||||
int num_out_channels = _out_feats.size(-1);
|
||||
|
||||
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
|
||||
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
|
||||
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
|
||||
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
|
||||
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
|
||||
int group_size = num_in_channels / _scaling_factors.size(0);
|
||||
|
||||
if (num_out_channels % 64 != 0)
|
||||
throw std::invalid_argument("OC is not multiple of cta_N = 64");
|
||||
if (num_out_channels % 8 != 0)
|
||||
throw std::invalid_argument("OC is not multiple of pack_num = 8");
|
||||
if (group_size % 32 != 0)
|
||||
throw std::invalid_argument("Group size should be a multiple of 32");
|
||||
if (num_out_channels % group_size != 0)
|
||||
throw std::invalid_argument("OC is not multiple of Group size");
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
if (num_out_channels % 128 == 0)
|
||||
{
|
||||
int j_factors1 = num_out_channels / 128 / 1;
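// Launch one block per (16-row output tile) x (128-column tile) x (split-K slice),
// flattened into a one-dimensional grid.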
|
||||
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||
// threadIdx.x: 32
|
||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||
dim3 threads_per_block(32, 2);
|
||||
vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||
}
|
||||
else if (num_out_channels % 64 == 0)
|
||||
{
|
||||
int j_factors1 = num_out_channels / 64 / 1;
|
||||
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||
|
||||
// threadIdx.x: 32
|
||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||
dim3 threads_per_block(32, 2);
|
||||
vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||
}
|
||||
return _out_feats.sum(0);
|
||||
}
|
148
csrc/quantization/squeezellm/quant_cuda_kernel.cu
Normal file
@ -0,0 +1,148 @@
|
||||
#include <torch/all.h>
|
||||
#include <torch/python.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
// half-tensor
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <ATen/cuda/CUDATensorMethods.cuh>
|
||||
|
||||
#define BLOCKWIDTH 128
|
||||
#define BLOCKHEIGHT4 16
|
||||
|
||||
namespace vllm {
|
||||
namespace squeezellm {
|
||||
|
||||
__device__ inline unsigned int as_unsigned(int i) {
|
||||
return *reinterpret_cast<unsigned int*>(&i);
|
||||
}
|
||||
|
||||
// 4-bit matvec kernel (LUT-based)
|
||||
__global__ void NUQ4MatMulKernel(
|
||||
const half2* __restrict__ vec,
|
||||
const int* __restrict__ mat,
|
||||
half2* __restrict__ mul,
|
||||
const __half* __restrict__ lookup_table,
|
||||
int height,
|
||||
int width,
|
||||
int batch,
|
||||
int vec_height
|
||||
) {
|
||||
|
||||
const int blockwidth2 = BLOCKWIDTH / 2;
|
||||
|
||||
int row = BLOCKHEIGHT4 * blockIdx.x;
|
||||
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
|
||||
__shared__ half2 blockvec[blockwidth2];
|
||||
|
||||
__shared__ __half deq2[16][BLOCKWIDTH];
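// Each thread stages the 16-entry dequantization LUT for its own output column
// into shared memory; deq2[code][threadIdx.x] is indexed by the 4-bit codes below.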
|
||||
int off = threadIdx.x;
|
||||
int column_offset = col * 16;
|
||||
for (int val = 0; val < 16; val += 1) {
|
||||
int lut_index = column_offset + val;
|
||||
deq2[val][off] = lookup_table[lut_index];
|
||||
}
|
||||
|
||||
__half res;
|
||||
half2 res2;
|
||||
half2 tmp2;
|
||||
|
||||
int i;
|
||||
int k;
|
||||
|
||||
unsigned int tmp1;
|
||||
unsigned int lut_index1, lut_index2;
|
||||
|
||||
for (int b = 0; b < batch; ++b){
|
||||
i = width * row + col;
|
||||
res = __int2half_rd(0);
|
||||
k = 0;
|
||||
|
||||
__syncthreads();
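// Stage this row-tile's slice of the activation vector (blockwidth2 half2 values)
// into shared memory; only the first blockwidth2 threads participate in the load.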
|
||||
if (threadIdx.x < blockwidth2)
|
||||
blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
|
||||
__syncthreads();
|
||||
|
||||
while (k < blockwidth2) {
|
||||
tmp1 = as_unsigned(mat[i]);
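// tmp1 packs eight 4-bit codes; each pair of nibbles below is looked up in the
// shared LUT to form a half2, which is fused-multiply-added against two
// consecutive activation half2 values.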
|
||||
|
||||
res2 = {};
|
||||
tmp2 = {};
|
||||
|
||||
lut_index1 = tmp1 & 0xF;
|
||||
lut_index2 = (tmp1 >> 4) & 0xF;
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
res2 = __hfma2(tmp2, blockvec[k + 0], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 8) & 0xF;
|
||||
lut_index2 = (tmp1 >> 12) & 0xF;
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
res2 = __hfma2(tmp2, blockvec[k + 1], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 16) & 0xF;
|
||||
lut_index2 = (tmp1 >> 20) & 0xF;
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
res2 = __hfma2(tmp2, blockvec[k + 2], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 24) & 0xF;
|
||||
lut_index2 = (tmp1 >> 28) & 0xF;
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
res2 = __hfma2(tmp2, blockvec[k + 3], res2);
|
||||
|
||||
res = __hadd(__hadd(res2.x, res2.y), res);
|
||||
|
||||
i += width;
|
||||
k += 4;
|
||||
}
|
||||
|
||||
// col%2 -> only set one of the two values
|
||||
half2 res3 = {};
|
||||
if (col % 2 == 0) {
|
||||
res3.x = res;
|
||||
} else {
|
||||
res3.y = res;
|
||||
}
|
||||
|
||||
atomicAdd(&mul[b * width / 2 + col / 2], res3);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace squeezellm
|
||||
} // namespace vllm
|
||||
|
||||
// 4-bit matvec kernel (LUT-based)
|
||||
void squeezellm_gemm(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor lookup_table
|
||||
) {
|
||||
int height = mat.size(0);
|
||||
int width = mat.size(1);
|
||||
|
||||
int batch = vec.size(0);
|
||||
int vec_height = vec.size(1);
|
||||
|
||||
dim3 blocks(
|
||||
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
|
||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
|
||||
);
|
||||
dim3 threads(BLOCKWIDTH);
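// One thread per output column within a BLOCKWIDTH-wide tile; the grid covers
// ceil(height / BLOCKHEIGHT4) row tiles by ceil(width / BLOCKWIDTH) column tiles
// of the packed weight matrix.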
|
||||
|
||||
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
|
||||
(half2*) vec.data<at::Half>(),
|
||||
mat.data_ptr<int>(),
|
||||
(half2*) mul.data<at::Half>(),
|
||||
(__half*) lookup_table.data<at::Half>(),
|
||||
height, width, batch, vec_height
|
||||
);
|
||||
}
|
||||
|
||||
#undef BLOCKWIDTH
|
||||
#undef BLOCKHEIGHT4
|
(8 binary image assets changed in this diff; previous sizes 244-285 KiB)
@ -3,31 +3,15 @@
|
||||
Installation
|
||||
============
|
||||
|
||||
vLLM is a Python library that also contains some C++ and CUDA code.
|
||||
This additional code requires compilation on the user's machine.
|
||||
vLLM is a Python library that also contains pre-compiled C++ and CUDA (11.8) binaries.
|
||||
|
||||
Requirements
|
||||
------------
|
||||
|
||||
* OS: Linux
|
||||
* Python: 3.8 or higher
|
||||
* CUDA: 11.0 -- 11.8
|
||||
* Python: 3.8 -- 3.11
|
||||
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
|
||||
|
||||
.. note::
|
||||
As of now, vLLM does not support CUDA 12.
|
||||
If you are using Hopper or Lovelace GPUs, please use CUDA 11.8 instead of CUDA 12.
|
||||
|
||||
.. tip::
|
||||
If you have trouble installing vLLM, we recommend using the NVIDIA PyTorch Docker image.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ # Pull the Docker image with CUDA 11.8.
|
||||
$ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
|
||||
|
||||
Inside the Docker container, please execute :code:`pip uninstall torch` before installing vLLM.
|
||||
|
||||
Install with pip
|
||||
----------------
|
||||
|
||||
@ -40,7 +24,7 @@ You can install vLLM using pip:
|
||||
$ conda activate myenv
|
||||
|
||||
$ # Install vLLM.
|
||||
$ pip install vllm # This may take 5-10 minutes.
|
||||
$ pip install vllm
|
||||
|
||||
|
||||
.. _build_from_source:
|
||||
@ -55,3 +39,12 @@ You can also build and install vLLM from source:
|
||||
$ git clone https://github.com/vllm-project/vllm.git
|
||||
$ cd vllm
|
||||
$ pip install -e . # This may take 5-10 minutes.
|
||||
|
||||
.. tip::
|
||||
If you have trouble building vLLM, we recommend using the NVIDIA PyTorch Docker image.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ # Pull the Docker image with CUDA 11.8.
|
||||
$ # Use `--ipc=host` to make sure the shared memory is large enough.
|
||||
$ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:22.12-py3
|
||||
|
@ -40,6 +40,16 @@ Initialize vLLM's engine for offline inference with the ``LLM`` class and the `O
|
||||
|
||||
llm = LLM(model="facebook/opt-125m")
|
||||
|
||||
To use a model from www.modelscope.cn:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
export VLLM_USE_MODELSCOPE=True
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
llm = LLM(model="qwen/Qwen-7B-Chat", revision="v1.1.8", trust_remote_code=True)
|
||||
|
||||
Call ``llm.generate`` to generate the outputs. It adds the input prompts to the vLLM engine's waiting queue and runs the vLLM engine to generate the outputs with high throughput. The outputs are returned as a list of ``RequestOutput`` objects, which include all of the output tokens.
|
||||
|
||||
.. code-block:: python
|
||||
@ -67,6 +77,16 @@ Start the server:
|
||||
|
||||
$ python -m vllm.entrypoints.api_server
|
||||
|
||||
To use a model from www.modelscope.cn:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ VLLM_USE_MODELSCOPE=True python -m vllm.entrypoints.api_server \
|
||||
$ --model="qwen/Qwen-7B-Chat" \
|
||||
$ --revision="v1.1.8" \
|
||||
$ --trust-remote-code
|
||||
|
||||
|
||||
By default, this command starts the server at ``http://localhost:8000`` with the OPT-125M model.
|
||||
|
||||
Query the model in shell:
|
||||
@ -95,6 +115,13 @@ Start the server:
|
||||
$ python -m vllm.entrypoints.openai.api_server \
|
||||
$ --model facebook/opt-125m
|
||||
|
||||
To use a model from www.modelscope.cn:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ VLLM_USE_MODELSCOPE=True python -m vllm.entrypoints.openai.api_server \
|
||||
$ --model="qwen/Qwen-7B-Chat" --revision="v1.1.8" --trust-remote-code
|
||||
|
||||
By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_ and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.
|
||||
|
||||
This server can be queried in the same format as OpenAI API. For example, list the models:
|
||||
@ -128,4 +155,4 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep
|
||||
prompt="San Francisco is a")
|
||||
print("Completion result:", completion)
|
||||
|
||||
For a more detailed client example, refer to `examples/openai_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_client.py>`_.
|
||||
For a more detailed client example, refer to `examples/openai_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`_.
|
||||
|
@ -43,6 +43,7 @@ vLLM is flexible and easy to use with:
|
||||
For more information, check out the following:
|
||||
|
||||
* `vLLM announcing blog post <https://vllm.ai>`_ (intro to PagedAttention)
|
||||
* `vLLM paper <https://arxiv.org/abs/2309.06180>`_ (SOSP 2023)
|
||||
* `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency <https://www.anyscale.com/blog/continuous-batching-llm-inference>`_ by Cade Daniel et al.
|
||||
|
||||
|
||||
@ -63,6 +64,8 @@ Documentation
|
||||
|
||||
serving/distributed_serving
|
||||
serving/run_on_sky
|
||||
serving/deploying_with_triton
|
||||
serving/deploying_with_docker
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
@ -70,3 +73,9 @@ Documentation
|
||||
|
||||
models/supported_models
|
||||
models/adding_model
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Quantization
|
||||
|
||||
quantization/auto_awq
|
@ -62,31 +62,34 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
|
||||
+) -> SamplerOutput:
|
||||
|
||||
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
|
||||
4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.
|
||||
4. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture.
|
||||
|
||||
.. note::
|
||||
Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
|
||||
If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM.
|
||||
|
||||
|
||||
3. (Optional) Implement tensor parallelism support
|
||||
--------------------------------------------------
|
||||
3. (Optional) Implement tensor parallelism and quantization support
|
||||
-------------------------------------------------------------------
|
||||
|
||||
If your model is too large to fit into a single GPU, you can use tensor parallelism to manage it.
|
||||
To do this, substitute your model's linear and embedding layers with their tensor-parallel versions.
|
||||
For the embedding layer, you can simply replace :code:`nn.Embedding` with :code:`VocabParallelEmbedding`.
|
||||
When it comes to the linear layers, you should use either :code:`RowParallelLinear` or :code:`ColumnParallelLinear`.
|
||||
Typically, :code:`ColumnParallelLinear` is used for QKV linear layers and the first linear layers of the MLP blocks.
|
||||
For the remaining linear layers, :code:`RowParallelLinear` is used.
|
||||
For the embedding layer, you can simply replace :code:`nn.Embedding` with :code:`VocabParallelEmbedding`. For the output LM head, you can use :code:`ParallelLMHead`.
|
||||
When it comes to the linear layers, we provide the following options to parallelize them:
|
||||
|
||||
* :code:`ReplicatedLinear`: Replicates the inputs and weights across multiple GPUs. No memory saving.
|
||||
* :code:`RowParallelLinear`: The input tensor is partitioned along the hidden dimension. The weight matrix is partitioned along the rows (input dimension). An *all-reduce* operation is performed after the matrix multiplication to reduce the results. Typically used for the second FFN layer and the output linear transformation of the attention layer.
|
||||
* :code:`ColumnParallelLinear`: The input tensor is replicated. The weight matrix is partitioned along the columns (output dimension). The result is partitioned along the column dimension. Typically used for the first FFN layer and the separated QKV transformation of the attention layer in the original Transformer.
|
||||
* :code:`MergedColumnParallelLinear`: Column-parallel linear that merges multiple `ColumnParallelLinear` operators. Typically used for the first FFN layer with weighted activation functions (e.g., SiLU). This class handles the sharded weight loading logic of multiple weight matrices.
|
||||
* :code:`QKVParallelLinear`: Parallel linear layer for the query, key, and value projections of the multi-head and grouped-query attention mechanisms. When the number of key/value heads is less than the world size, this class replicates the key/value heads properly. It also handles the weight loading and replication of the weight matrices.
|
||||
|
||||
Note that all the linear layers above take `linear_method` as an input. vLLM will set this parameter according to different quantization schemes to support weight quantization.
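
For a rough illustration of what this looks like in practice, a two-layer MLP block might be parallelized as sketched below. The import path, the :code:`linear_method` plumbing, and the exact constructor arguments are assumptions for illustration and may differ from the actual vLLM API.

.. code-block:: python

# Illustrative sketch only: the import path and constructor arguments are
# assumptions; consult the vLLM source for the exact API.
import torch.nn as nn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)

class MyMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int,
                 linear_method=None):
        super().__init__()
        # First FFN layer: the weight is sharded along the output dimension.
        self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size,
                                            linear_method=linear_method)
        # Second FFN layer: the weight is sharded along the input dimension;
        # an all-reduce combines the partial results across GPUs.
        self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                           linear_method=linear_method)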
|
||||
|
||||
4. Implement the weight loading logic
|
||||
-------------------------------------
|
||||
|
||||
You now need to implement the :code:`load_weights` method in your :code:`*ForCausalLM` class.
|
||||
This method should load the weights from the HuggingFace's checkpoint file and assign them to the corresponding layers in your model.
|
||||
While the process is straightforward for most layers, the tensor-parallel layers necessitate some additional care as their weights should be partitioned to multiple GPUs.
|
||||
|
||||
This method should load the weights from the HuggingFace checkpoint file and assign them to the corresponding layers in your model. Specifically, for `MergedColumnParallelLinear` and `QKVParallelLinear` layers, if the original model has separate weight matrices, you need to load the different parts separately.
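
As a condensed sketch of the overall pattern (the method signature, the :code:`hf_model_weights_iterator` helper, and the :code:`weight_loader` attribute are assumptions here and may not match the code base exactly):

.. code-block:: python

# Illustrative sketch only: helper names and signatures are assumptions.
def load_weights(self, model_name_or_path, cache_dir=None,
                 load_format="auto", revision=None):
    # Checkpoint tensors that map onto a single merged vLLM parameter.
    stacked_params_mapping = [
        # (vllm parameter name, checkpoint parameter name, shard id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]
    params_dict = dict(self.named_parameters())
    for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision):
        for vllm_name, ckpt_name, shard_id in stacked_params_mapping:
            if ckpt_name in name:
                # A separated checkpoint matrix is loaded into its shard
                # of the merged vLLM parameter.
                param = params_dict[name.replace(ckpt_name, vllm_name)]
                param.weight_loader(param, loaded_weight, shard_id)
                break
        else:
            param = params_dict[name]
            param.weight_loader(param, loaded_weight)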
|
||||
|
||||
5. Register your model
|
||||
----------------------
|
||||
|
@ -20,12 +20,15 @@ Alongside each architecture, we include some popular models that use it.
|
||||
* - :code:`BaiChuanForCausalLM`
|
||||
- Baichuan
|
||||
- :code:`baichuan-inc/Baichuan-7B`, :code:`baichuan-inc/Baichuan-13B-Chat`, etc.
|
||||
* - :code:`ChatGLMModel`
|
||||
- ChatGLM
|
||||
- :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
|
||||
* - :code:`BloomForCausalLM`
|
||||
- BLOOM, BLOOMZ, BLOOMChat
|
||||
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
|
||||
* - :code:`FalconForCausalLM`
|
||||
- Falcon
|
||||
- :code:`tiiuae/falcon-7b``, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
|
||||
- :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
|
||||
* - :code:`GPT2LMHeadModel`
|
||||
- GPT-2
|
||||
- :code:`gpt2`, :code:`gpt2-xl`, etc.
|
||||
@ -44,15 +47,24 @@ Alongside each architecture, we include some popular models that use it.
|
||||
* - :code:`LlamaForCausalLM`
|
||||
- LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
|
||||
- :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, etc.
|
||||
* - :code:`MistralForCausalLM`
|
||||
- Mistral, Mistral-Instruct
|
||||
- :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
|
||||
* - :code:`MPTForCausalLM`
|
||||
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
|
||||
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
|
||||
* - :code:`OPTForCausalLM`
|
||||
- OPT, OPT-IML
|
||||
- :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
|
||||
* - :code:`PhiForCausalLM`
|
||||
- Phi-1.5
|
||||
- :code:`microsoft/phi-1_5`, etc.
|
||||
* - :code:`QWenLMHeadModel`
|
||||
- Qwen
|
||||
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
|
||||
* - :code:`YiForCausalLM`
|
||||
- Yi
|
||||
- :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
|
||||
|
||||
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
|
||||
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
|
||||
@ -69,4 +81,18 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
|
||||
output = llm.generate("Hello, my name is")
|
||||
print(output)
|
||||
|
||||
To use a model from www.modelscope.cn:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
$ export VLLM_USE_MODELSCOPE=True
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model
|
||||
output = llm.generate("Hello, my name is")
|
||||
print(output)
|
||||
|
||||
If vLLM successfully generates text, it indicates that your model is supported.
|
||||
|
69
docs/source/quantization/auto_awq.rst
Normal file
@ -0,0 +1,69 @@
|
||||
.. _auto_awq:
|
||||
|
||||
AutoAWQ
|
||||
==================
|
||||
|
||||
To create a new 4-bit quantized model, you can leverage `AutoAWQ <https://github.com/casper-hansen/AutoAWQ>`_.
|
||||
Quantizing reduces the model's precision from FP16 to INT4, which effectively reduces the file size by ~70%.
|
||||
The main benefits are lower latency and memory usage.
|
||||
|
||||
You can quantize your own models by installing AutoAWQ or picking one of the `400+ models on Huggingface <https://huggingface.co/models?sort=trending&search=awq>`_.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install autoawq
|
||||
|
||||
After installing AutoAWQ, you are ready to quantize a model. Here is an example of how to quantize Vicuna 7B v1.5:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from awq import AutoAWQForCausalLM
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
model_path = 'lmsys/vicuna-7b-v1.5'
|
||||
quant_path = 'vicuna-7b-v1.5-awq'
|
||||
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
|
||||
|
||||
# Load model
|
||||
model = AutoAWQForCausalLM.from_pretrained(model_path, **{"low_cpu_mem_usage": True})
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
# Quantize
|
||||
model.quantize(tokenizer, quant_config=quant_config)
|
||||
|
||||
# Save quantized model
|
||||
model.save_quantized(quant_path)
|
||||
tokenizer.save_pretrained(quant_path)
|
||||
|
||||
To run an AWQ model with vLLM, you can use `TheBloke/Llama-2-7b-Chat-AWQ <https://huggingface.co/TheBloke/Llama-2-7b-Chat-AWQ>`_ with the following command:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python examples/llm_engine_example.py --model TheBloke/Llama-2-7b-Chat-AWQ --quantization awq
|
||||
|
||||
AWQ models are also supported directly through the LLM entrypoint:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(model="TheBloke/Llama-2-7b-Chat-AWQ", quantization="AWQ")
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
21
docs/source/serving/deploying_with_docker.rst
Normal file
@ -0,0 +1,21 @@
|
||||
.. _deploying_with_docker:
|
||||
|
||||
Deploying with Docker
|
||||
============================
|
||||
|
||||
You can build and run vLLM from source via the provided Dockerfile. To build vLLM:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ DOCKER_BUILDKIT=1 docker build . --target vllm --tag vllm --build-arg max_jobs=8
|
||||
|
||||
To run vLLM:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ docker run --runtime nvidia --gpus all \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
-p 8000:8000 \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=<secret>" \
|
||||
vllm <args...>
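
Assuming the container runs the OpenAI-compatible server on the mapped port 8000 (the endpoints described in the Quickstart), a minimal smoke test from the host could look like the snippet below; the model name returned will be whatever you passed in :code:`<args...>`.

.. code-block:: python

# Minimal smoke test; assumes the container serves the OpenAI-compatible
# API on the mapped port 8000.
import requests

models = requests.get("http://localhost:8000/v1/models").json()
print(models)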
|
||||
|
6
docs/source/serving/deploying_with_triton.rst
Normal file
@ -0,0 +1,6 @@
|
||||
.. _deploying_with_triton:
|
||||
|
||||
Deploying with NVIDIA Triton
|
||||
============================
|
||||
|
||||
The `Triton Inference Server <https://github.com/triton-inference-server>`_ hosts a tutorial demonstrating how to quickly deploy a simple `facebook/opt-125m <https://huggingface.co/facebook/opt-125m>`_ model using vLLM. Please see `Deploying a vLLM model in Triton <https://github.com/triton-inference-server/tutorials/blob/main/Quick_Deploy/vLLM/README.md#deploying-a-vllm-model-in-triton>`_ for more details.
|
@ -39,7 +39,7 @@ def build_demo():
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--host", type=str, default=None)
|
||||
parser.add_argument("--port", type=int, default=8001)
|
||||
parser.add_argument("--model-url",
|
||||
type=str,
|
||||
|
@ -1,17 +1,14 @@
|
||||
import argparse
|
||||
from typing import List, Tuple
|
||||
|
||||
from vllm import EngineArgs, LLMEngine, SamplingParams
|
||||
from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
# Parse the CLI argument and initialize the engine.
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
|
||||
# Test the following prompts.
|
||||
test_prompts = [
|
||||
def create_test_prompts() -> List[Tuple[str, SamplingParams]]:
|
||||
"""Create a list of test prompts with their sampling parameters."""
|
||||
return [
|
||||
("A robot may not injure a human being",
|
||||
SamplingParams(temperature=0.0)),
|
||||
SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)),
|
||||
("To be or not to be,",
|
||||
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
|
||||
("What is the meaning of life?",
|
||||
@ -25,22 +22,36 @@ def main(args: argparse.Namespace):
|
||||
temperature=0.0)),
|
||||
]
|
||||
|
||||
# Run the engine by calling `engine.step()` manually.
|
||||
|
||||
def process_requests(engine: LLMEngine,
|
||||
test_prompts: List[Tuple[str, SamplingParams]]):
|
||||
"""Continuously process a list of prompts and handle the outputs."""
|
||||
request_id = 0
|
||||
while True:
|
||||
# To test continuous batching, we add one request at each step.
|
||||
|
||||
while test_prompts or engine.has_unfinished_requests():
|
||||
if test_prompts:
|
||||
prompt, sampling_params = test_prompts.pop(0)
|
||||
engine.add_request(str(request_id), prompt, sampling_params)
|
||||
request_id += 1
|
||||
|
||||
request_outputs = engine.step()
|
||||
request_outputs: List[RequestOutput] = engine.step()
|
||||
|
||||
for request_output in request_outputs:
|
||||
if request_output.finished:
|
||||
print(request_output)
|
||||
|
||||
if not (engine.has_unfinished_requests() or test_prompts):
|
||||
break
|
||||
|
||||
def initialize_engine(args: argparse.Namespace) -> LLMEngine:
|
||||
"""Initialize the LLMEngine from the command line arguments."""
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
return LLMEngine.from_engine_args(engine_args)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
"""Main function that sets up and runs the prompt processing."""
|
||||
engine = initialize_engine(args)
|
||||
test_prompts = create_test_prompts()
|
||||
process_requests(engine, test_prompts)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
39
format.sh
@ -44,7 +44,6 @@ YAPF_FLAGS=(
|
||||
|
||||
YAPF_EXCLUDES=(
|
||||
'--exclude' 'build/**'
|
||||
'--exclude' 'vllm/model_executor/parallel_utils/**'
|
||||
)
|
||||
|
||||
# Format specified files
|
||||
@ -72,7 +71,7 @@ format_changed() {
|
||||
|
||||
# Format all files
|
||||
format_all() {
|
||||
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm
|
||||
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm tests
|
||||
}
|
||||
|
||||
## This flag formats individual files. --files *must* be the first command line
|
||||
@ -94,9 +93,43 @@ echo 'vLLM yapf: Done'
|
||||
# echo 'vLLM mypy:'
|
||||
# mypy
|
||||
|
||||
# Lint specified files
|
||||
lint() {
|
||||
pylint "$@"
|
||||
}
|
||||
|
||||
# Lint files that differ from main branch. Ignores dirs that are not slated
|
||||
# for autolint yet.
|
||||
lint_changed() {
|
||||
# The `if` guard ensures that the list of filenames is not empty, which
|
||||
# could cause pylint to receive 0 positional arguments, making it hang
|
||||
# waiting for STDIN.
|
||||
#
|
||||
# `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that
|
||||
# exist on both branches.
|
||||
MERGEBASE="$(git merge-base origin/main HEAD)"
|
||||
|
||||
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
|
||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
|
||||
pylint
|
||||
fi
|
||||
|
||||
}
|
||||
|
||||
# Run Pylint
|
||||
echo 'vLLM Pylint:'
|
||||
pylint vllm
|
||||
## This flag lints individual files. --files *must* be the first command line
|
||||
## arg to use this option.
|
||||
if [[ "$1" == '--files' ]]; then
|
||||
lint "${@:2}"
|
||||
# If `--all` is passed, then any further arguments are ignored and the
|
||||
# entire python directory is linted.
|
||||
elif [[ "$1" == '--all' ]]; then
|
||||
lint vllm tests
|
||||
else
|
||||
# Format only the files that changed in last commit.
|
||||
lint_changed
|
||||
fi
|
||||
|
||||
if ! git diff --quiet &>/dev/null; then
|
||||
echo 'Reformatted files. Please review and stage the changes.'
|
||||
|
@ -3,7 +3,7 @@ requires = [
|
||||
"ninja",
|
||||
"packaging",
|
||||
"setuptools",
|
||||
"torch >= 2.0.0",
|
||||
"torch >= 2.1.0",
|
||||
"wheel",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
@ -11,3 +11,5 @@ types-setuptools
|
||||
# testing
|
||||
pytest
|
||||
pytest-forked
|
||||
pytest-asyncio
|
||||
|
||||
|
@ -1,11 +1,14 @@
|
||||
ninja # For faster builds.
|
||||
psutil
|
||||
ray >= 2.5.1
|
||||
pandas # Required for Ray data.
|
||||
pyarrow # Required for Ray data.
|
||||
sentencepiece # Required for LLaMA tokenizer.
|
||||
numpy
|
||||
torch >= 2.0.0
|
||||
transformers >= 4.33.1 # Required for Code Llama.
|
||||
xformers >= 0.0.21
|
||||
einops # Required for phi-1_5
|
||||
torch >= 2.1.0
|
||||
transformers >= 4.34.0 # Required for Mistral.
|
||||
xformers >= 0.0.22.post7 # Required for CUDA 12.1.
|
||||
fastapi
|
||||
uvicorn
|
||||
pydantic < 2 # Required for OpenAI server.
|
||||
uvicorn[standard]
|
||||
pydantic == 1.10.13 # Required for OpenAI server.
|
||||
|
207
setup.py
@ -3,6 +3,7 @@ import os
|
||||
import re
|
||||
import subprocess
|
||||
from typing import List, Set
|
||||
import warnings
|
||||
|
||||
from packaging.version import parse, Version
|
||||
import setuptools
|
||||
@ -11,6 +12,11 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
|
||||
|
||||
ROOT_DIR = os.path.dirname(__file__)
|
||||
|
||||
MAIN_CUDA_VERSION = "12.1"
|
||||
|
||||
# Supported NVIDIA GPU architectures.
|
||||
SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
|
||||
|
||||
# Compiler flags.
|
||||
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
|
||||
# TODO(woosuk): Should we use -O3?
|
||||
@ -22,7 +28,7 @@ NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
||||
|
||||
if CUDA_HOME is None:
|
||||
raise RuntimeError(
|
||||
f"Cannot find CUDA_HOME. CUDA must be available to build the package.")
|
||||
"Cannot find CUDA_HOME. CUDA must be available to build the package.")
|
||||
|
||||
|
||||
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
||||
@ -38,47 +44,95 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
||||
return nvcc_cuda_version
|
||||
|
||||
|
||||
# Collect the compute capabilities of all available GPUs.
|
||||
device_count = torch.cuda.device_count()
|
||||
compute_capabilities: Set[int] = set()
|
||||
for i in range(device_count):
|
||||
major, minor = torch.cuda.get_device_capability(i)
|
||||
if major < 7:
|
||||
def get_torch_arch_list() -> Set[str]:
|
||||
# TORCH_CUDA_ARCH_LIST can have one or more architectures,
|
||||
# e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
|
||||
# compiler to additionally include PTX code that can be runtime-compiled
|
||||
# and executed on the 8.6 or newer architectures. While the PTX code will
|
||||
# not give the best performance on the newer architectures, it provides
|
||||
# forward compatibility.
|
||||
env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
|
||||
if env_arch_list is None:
|
||||
return set()
|
||||
|
||||
# List entries are separated by ';' or space.
|
||||
torch_arch_list = set(env_arch_list.replace(" ", ";").split(";"))
|
||||
if not torch_arch_list:
|
||||
return set()
|
||||
|
||||
# Filter out the invalid architectures and print a warning.
|
||||
valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS})
|
||||
arch_list = torch_arch_list.intersection(valid_archs)
|
||||
# If none of the specified architectures are valid, raise an error.
|
||||
if not arch_list:
|
||||
raise RuntimeError(
|
||||
"GPUs with compute capability less than 7.0 are not supported.")
|
||||
compute_capabilities.add(major * 10 + minor)
|
||||
"None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
|
||||
f"variable ({env_arch_list}) is supported. "
|
||||
f"Supported CUDA architectures are: {valid_archs}.")
|
||||
invalid_arch_list = torch_arch_list - valid_archs
|
||||
if invalid_arch_list:
|
||||
warnings.warn(
|
||||
f"Unsupported CUDA architectures ({invalid_arch_list}) are "
|
||||
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
|
||||
f"({env_arch_list}). Supported CUDA architectures are: "
|
||||
f"{valid_archs}.")
|
||||
return arch_list
|
||||
|
||||
|
||||
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
|
||||
compute_capabilities = get_torch_arch_list()
|
||||
if not compute_capabilities:
|
||||
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
|
||||
# GPUs on the current machine.
|
||||
device_count = torch.cuda.device_count()
|
||||
for i in range(device_count):
|
||||
major, minor = torch.cuda.get_device_capability(i)
|
||||
if major < 7:
|
||||
raise RuntimeError(
|
||||
"GPUs with compute capability below 7.0 are not supported.")
|
||||
compute_capabilities.add(f"{major}.{minor}")
|
||||
|
||||
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
||||
if not compute_capabilities:
|
||||
# If no GPU is specified nor available, add all supported architectures
|
||||
# based on the NVCC CUDA version.
|
||||
compute_capabilities = SUPPORTED_ARCHS.copy()
|
||||
if nvcc_cuda_version < Version("11.1"):
|
||||
compute_capabilities.remove("8.6")
|
||||
if nvcc_cuda_version < Version("11.8"):
|
||||
compute_capabilities.remove("8.9")
|
||||
compute_capabilities.remove("9.0")
|
||||
|
||||
# Validate the NVCC CUDA version.
|
||||
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
||||
if nvcc_cuda_version < Version("11.0"):
|
||||
raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
|
||||
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
|
||||
if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
||||
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
|
||||
# However, GPUs with compute capability 8.9 can also run the code generated by
|
||||
# the previous versions of CUDA 11 and targeting compute capability 8.0.
|
||||
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
|
||||
# instead of 8.9.
|
||||
compute_capabilities.remove(89)
|
||||
compute_capabilities.add(80)
|
||||
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
|
||||
|
||||
# If no GPU is available, add all supported compute capabilities.
|
||||
if not compute_capabilities:
|
||||
compute_capabilities = {70, 75, 80}
|
||||
if nvcc_cuda_version >= Version("11.1"):
|
||||
compute_capabilities.add(86)
|
||||
if nvcc_cuda_version >= Version("11.8"):
|
||||
compute_capabilities.add(89)
|
||||
compute_capabilities.add(90)
|
||||
if nvcc_cuda_version < Version("11.1"):
|
||||
if any(cc.startswith("8.6") for cc in compute_capabilities):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.1 or higher is required for compute capability 8.6.")
|
||||
if nvcc_cuda_version < Version("11.8"):
|
||||
if any(cc.startswith("8.9") for cc in compute_capabilities):
|
||||
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
|
||||
# However, GPUs with compute capability 8.9 can also run the code generated by
|
||||
# the previous versions of CUDA 11 and targeting compute capability 8.0.
|
||||
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
|
||||
# instead of 8.9.
|
||||
warnings.warn(
|
||||
"CUDA 11.8 or higher is required for compute capability 8.9. "
|
||||
"Targeting compute capability 8.0 instead.")
|
||||
compute_capabilities = set(cc for cc in compute_capabilities
|
||||
if not cc.startswith("8.9"))
|
||||
compute_capabilities.add("8.0+PTX")
|
||||
if any(cc.startswith("9.0") for cc in compute_capabilities):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.8 or higher is required for compute capability 9.0.")
|
||||
|
||||
# Add target compute capabilities to NVCC flags.
|
||||
for capability in compute_capabilities:
|
||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
|
||||
num = capability[0] + capability[2]
|
||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
|
||||
if capability.endswith("+PTX"):
|
||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
|
||||
|
||||
# Use NVCC threads to parallelize the build.
|
||||
if nvcc_cuda_version >= Version("11.2"):
|
||||
@ -91,7 +145,10 @@ ext_modules = []
|
||||
cache_extension = CUDAExtension(
|
||||
name="vllm.cache_ops",
|
||||
sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
|
||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(cache_extension)
|
||||
|
||||
@ -99,7 +156,10 @@ ext_modules.append(cache_extension)
|
||||
attention_extension = CUDAExtension(
|
||||
name="vllm.attention_ops",
|
||||
sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
|
||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(attention_extension)
|
||||
|
||||
@ -107,7 +167,10 @@ ext_modules.append(attention_extension)
|
||||
positional_encoding_extension = CUDAExtension(
|
||||
name="vllm.pos_encoding_ops",
|
||||
sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
|
||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(positional_encoding_extension)
|
||||
|
||||
@ -115,7 +178,10 @@ ext_modules.append(positional_encoding_extension)
|
||||
layernorm_extension = CUDAExtension(
|
||||
name="vllm.layernorm_ops",
|
||||
sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
|
||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(layernorm_extension)
|
||||
|
||||
@ -123,31 +189,73 @@ ext_modules.append(layernorm_extension)
|
||||
activation_extension = CUDAExtension(
|
||||
name="vllm.activation_ops",
|
||||
sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
|
||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(activation_extension)
|
||||
|
||||
# Quantization kernels.
|
||||
quantization_extension = CUDAExtension(
|
||||
name="vllm.quantization_ops",
|
||||
sources=[
|
||||
"csrc/quantization.cpp",
|
||||
"csrc/quantization/awq/gemm_kernels.cu",
|
||||
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
|
||||
],
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(quantization_extension)
|
||||
|
||||
# Misc. CUDA utils.
|
||||
cuda_utils_extension = CUDAExtension(
|
||||
name="vllm.cuda_utils",
|
||||
sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"],
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(cuda_utils_extension)
|
||||
|
||||
|
||||
def get_path(*filepath) -> str:
|
||||
return os.path.join(ROOT_DIR, *filepath)
|
||||
|
||||
|
||||
def find_version(filepath: str):
|
||||
def find_version(filepath: str) -> str:
|
||||
"""Extract version information from the given filepath.
|
||||
|
||||
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
|
||||
"""
|
||||
with open(filepath) as fp:
|
||||
version_match = re.search(
|
||||
r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M)
|
||||
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
|
||||
fp.read(), re.M)
|
||||
if version_match:
|
||||
return version_match.group(1)
|
||||
raise RuntimeError("Unable to find version string.")
|
||||
|
||||
|
||||
def get_vllm_version() -> str:
|
||||
version = find_version(get_path("vllm", "__init__.py"))
|
||||
cuda_version = str(nvcc_cuda_version)
|
||||
if cuda_version != MAIN_CUDA_VERSION:
|
||||
cuda_version_str = cuda_version.replace(".", "")[:3]
|
||||
version += f"+cu{cuda_version_str}"
|
||||
return version
|
||||
|
||||
|
||||
def read_readme() -> str:
|
||||
"""Read the README file."""
|
||||
return io.open(get_path("README.md"), "r", encoding="utf-8").read()
|
||||
"""Read the README file if present."""
|
||||
p = get_path("README.md")
|
||||
if os.path.isfile(p):
|
||||
return io.open(get_path("README.md"), "r", encoding="utf-8").read()
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
def get_requirements() -> List[str]:
|
||||
@ -159,10 +267,11 @@ def get_requirements() -> List[str]:
|
||||
|
||||
setuptools.setup(
|
||||
name="vllm",
|
||||
version=find_version(get_path("vllm", "__init__.py")),
|
||||
version=get_vllm_version(),
|
||||
author="vLLM Team",
|
||||
license="Apache 2.0",
|
||||
description="A high-throughput and memory-efficient inference and serving engine for LLMs",
|
||||
description=("A high-throughput and memory-efficient inference and "
|
||||
"serving engine for LLMs"),
|
||||
long_description=read_readme(),
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/vllm-project/vllm",
|
||||
@ -174,13 +283,15 @@ setuptools.setup(
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
],
|
||||
packages=setuptools.find_packages(
|
||||
exclude=("assets", "benchmarks", "csrc", "docs", "examples", "tests")),
|
||||
packages=setuptools.find_packages(exclude=("benchmarks", "csrc", "docs",
|
||||
"examples", "tests")),
|
||||
python_requires=">=3.8",
|
||||
install_requires=get_requirements(),
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
package_data={"vllm": ["py.typed"]},
|
||||
)
|
||||
|
0
tests/__init__.py
Normal file
@ -14,6 +14,7 @@ app = vllm.entrypoints.api_server.app
|
||||
|
||||
class AsyncLLMEngineWithStats(AsyncLLMEngine):
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._num_aborts = 0
|
||||
@ -40,8 +41,7 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngineWithStats.from_engine_args(engine_args,
|
||||
start_engine_loop=False)
|
||||
engine = AsyncLLMEngineWithStats.from_engine_args(engine_args)
|
||||
vllm.entrypoints.api_server.engine = engine
|
||||
uvicorn.run(
|
||||
app,
|
||||
|
@ -24,6 +24,7 @@ def _query_server(prompt: str) -> dict:
|
||||
def api_server():
|
||||
script_path = Path(__file__).parent.joinpath(
|
||||
"api_server_async_engine.py").absolute()
|
||||
# pylint: disable=consider-using-with
|
||||
uvicorn_process = subprocess.Popen([
|
||||
sys.executable, "-u",
|
||||
str(script_path), "--model", "facebook/opt-125m"
|
||||
@ -32,6 +33,7 @@ def api_server():
|
||||
uvicorn_process.terminate()
|
||||
|
||||
|
||||
# pylint: disable=redefined-outer-name, unused-argument
|
||||
def test_api_server(api_server):
|
||||
"""
|
||||
Run the API server and test it.
|
||||
@ -47,6 +49,7 @@ def test_api_server(api_server):
|
||||
prompts = ["Hello world"] * 1
|
||||
result = None
|
||||
while not result:
|
||||
# pylint: disable=bare-except
|
||||
try:
|
||||
for result in pool.map(_query_server, prompts):
|
||||
break
|
||||
|
80
tests/async_engine/test_async_llm_engine.py
Normal file
@ -0,0 +1,80 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestOutput:
|
||||
request_id: int
|
||||
finished: bool = False
|
||||
|
||||
|
||||
class MockEngine:
|
||||
|
||||
def __init__(self):
|
||||
self.step_calls = 0
|
||||
self.add_request_calls = 0
|
||||
self.abort_request_calls = 0
|
||||
self.request_id = None
|
||||
|
||||
async def step_async(self):
|
||||
self.step_calls += 1
|
||||
return [RequestOutput(
|
||||
request_id=self.request_id)] if self.request_id else []
|
||||
|
||||
def generate(self, request_id):
|
||||
self.request_id = request_id
|
||||
|
||||
def stop_generating(self):
|
||||
self.request_id = None
|
||||
|
||||
def add_request(self, **kwargs):
|
||||
del kwargs # Unused
|
||||
self.add_request_calls += 1
|
||||
|
||||
def abort_request(self, request_id):
|
||||
del request_id # Unused
|
||||
self.abort_request_calls += 1
|
||||
|
||||
|
||||
class MockAsyncLLMEngine(AsyncLLMEngine):
|
||||
|
||||
def _init_engine(self, *args, **kwargs):
|
||||
return MockEngine()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_requests_event():
|
||||
engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
|
||||
engine.start_background_loop()
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.step_calls == 0
|
||||
|
||||
await engine.add_request("1", "", None)
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.add_request_calls == 1
|
||||
assert engine.engine.step_calls == 1
|
||||
|
||||
await engine.add_request("2", "", None)
|
||||
engine.engine.generate("2")
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.add_request_calls == 2
|
||||
assert engine.engine.step_calls == 2
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.step_calls == 3
|
||||
engine.engine.stop_generating()
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.step_calls == 4
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.step_calls == 4
|
||||
|
||||
await engine.add_request("3", "", None)
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.add_request_calls == 3
|
||||
assert engine.engine.step_calls == 5
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.add_request_calls == 3
|
||||
assert engine.engine.step_calls == 5
|
@ -4,10 +4,25 @@ from vllm.engine.async_llm_engine import RequestTracker
from vllm.outputs import RequestOutput


class DummyEvent:

    def __init__(self):
        self.flag = False

    def set(self):
        self.flag = True

    def clear(self):
        self.flag = False


def test_request_tracker():
    tracker = RequestTracker()
    tracker.new_requests_event = DummyEvent()
    stream_1 = tracker.add_request("1")
    assert tracker.new_requests_event.flag
    new, finished = tracker.get_new_and_finished_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 1
    assert new[0]["request_id"] == "1"
    assert not finished
@ -15,7 +30,9 @@ def test_request_tracker():

    stream_2 = tracker.add_request("2")
    stream_3 = tracker.add_request("3")
    assert tracker.new_requests_event.flag
    new, finished = tracker.get_new_and_finished_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 2
    assert new[0]["request_id"] == "2"
    assert new[1]["request_id"] == "3"
@ -26,6 +43,7 @@ def test_request_tracker():
    # request_ids must be unique
    with pytest.raises(KeyError):
        tracker.add_request("1")
    assert not tracker.new_requests_event.flag

    tracker.abort_request("1")
    new, finished = tracker.get_new_and_finished_requests()
@ -36,6 +54,7 @@ def test_request_tracker():

    stream_4 = tracker.add_request("4")
    tracker.abort_request("4")
    assert tracker.new_requests_event.flag
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "4" in finished
@ -43,9 +62,11 @@ def test_request_tracker():
    assert stream_4.finished

    stream_5 = tracker.add_request("5")
    assert tracker.new_requests_event.flag
    tracker.process_request_output(
        RequestOutput("2", "output", [], [], finished=True))
        RequestOutput("2", "output", [], [], [], finished=True))
    new, finished = tracker.get_new_and_finished_requests()
    assert not tracker.new_requests_event.flag
    assert len(finished) == 1
    assert "2" in finished
    assert len(new) == 1
@ -8,6 +8,7 @@ from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer

_TEST_PROMPTS = [
    # pylint: disable=line-too-long
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
    "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
    "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
@ -106,6 +107,39 @@ class HfRunner:
            outputs[i] = (output_ids, output_str)
        return outputs

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[List[torch.Tensor]]:
        all_logprobs = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            seq_logprobs = []
            for hidden_states in output.hidden_states:
                last_hidden_states = hidden_states[-1][0]
                logits = torch.matmul(
                    last_hidden_states,
                    self.model.get_output_embeddings().weight.t(),
                )
                if self.model.get_output_embeddings().bias is not None:
                    logits += self.model.get_output_embeddings(
                    ).bias.unsqueeze(0)
                logprobs = torch.nn.functional.log_softmax(logits,
                                                           dim=-1,
                                                           dtype=torch.float32)
                seq_logprobs.append(logprobs)
            all_logprobs.append(seq_logprobs)
        return all_logprobs


@pytest.fixture
def hf_runner():
83  tests/distributed/test_comm_ops.py  Normal file
@ -0,0 +1,83 @@
"""Test the communication operators.

Run `pytest tests/distributed/test_comm_ops.py --forked`.
"""
from multiprocessing import Process, set_start_method

import pytest
import torch

from vllm.config import ParallelConfig
from vllm.engine.ray_utils import get_open_port
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce,
    tensor_model_parallel_all_gather,
)
from vllm.worker.worker import _init_distributed_environment


def init_test_distributed_environment(pipeline_parallel_size: int,
                                      tensor_parallel_size: int, rank: int,
                                      distributed_init_port: str):
    parallel_config = ParallelConfig(pipeline_parallel_size,
                                     tensor_parallel_size,
                                     worker_use_ray=True)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    torch.cuda.set_device(rank)
    _init_distributed_environment(parallel_config, rank,
                                  distributed_init_method)


def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str):
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    num_elements = 8
    all_tensors = [
        torch.arange(num_elements, dtype=torch.float32, device="cuda") *
        (r + 1) for r in range(tensor_parallel_size)
    ]
    expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
    t = all_tensors[rank]
    t = tensor_model_parallel_all_reduce(t)
    assert torch.allclose(t, expected)


def all_gather_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str):
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    num_dimensions = 3
    tensor_size = list(range(2, num_dimensions + 2))
    total_size = 1
    for s in tensor_size:
        total_size *= s
    for all_gather_dimension in range(num_dimensions):
        all_tensors = [
            torch.arange(total_size, dtype=torch.float32,
                         device="cuda").reshape(tensor_size) * (r + 1)
            for r in range(tensor_parallel_size)
        ]
        expected = torch.cat(all_tensors, dim=all_gather_dimension)
        t = all_tensors[rank]
        t = tensor_model_parallel_all_gather(t, all_gather_dimension)
        assert torch.allclose(t, expected)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("test_target",
                         [all_reduce_test_worker, all_gather_test_worker])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
    set_start_method("spawn", force=True)
    distributed_init_port = get_open_port()
    processes = []
    for rank in range(tensor_parallel_size):
        p = Process(target=test_target,
                    args=(tensor_parallel_size, rank, distributed_init_port))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
    assert all(p.exitcode == 0 for p in processes)
63  tests/engine/test_detokenize.py  Normal file
@ -0,0 +1,63 @@
import pytest

from transformers import AutoTokenizer

from vllm.transformers_utils.tokenizer import detokenize_incrementally

TRUTH = [
    # pylint: disable=line-too-long
    "Hello here, this is a simple test",
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",
    "我很感谢你的热情"
]
TOKENIZERS = [
    "facebook/opt-125m",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-70m",
    "bigscience/bloom-560m",
    "mosaicml/mpt-7b",
    "tiiuae/falcon-7b",
    "meta-llama/Llama-2-7b-hf",
    "codellama/CodeLlama-7b-hf",
]


def _run_incremental_decode(tokenizer, all_input_ids,
                            skip_special_tokens: bool):
    decoded_text = ""
    offset = 0
    token_offset = 0
    prev_tokens = None
    for i in range(len(all_input_ids)):
        new_tokens, text, offset, token_offset = detokenize_incrementally(
            tokenizer,
            all_input_ids[:i + 1],
            prev_tokens,
            offset,
            token_offset,
            skip_special_tokens=skip_special_tokens)
        decoded_text += text
        if prev_tokens is None:
            prev_tokens = new_tokens
        else:
            prev_tokens += new_tokens
    return decoded_text


@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False))
def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
    if skip_special_tokens:
        all_input_ids = ([tokenizer.bos_token_id]
                         if tokenizer.bos_token_id is not None else
                         []) + all_input_ids + [tokenizer.eos_token_id]

    decoded_text = _run_incremental_decode(
        tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens)

    assert decoded_text == truth
@ -29,8 +29,8 @@ def test_silu_and_mul(
) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
    activation_ops.silu_and_mul(out, x)
    ref_out = ref_silu_and_mul(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
@ -49,8 +49,8 @@ def test_gelu_new(
) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
    activation_ops.gelu_new(out, x)
    ref_out = get_activation("gelu_new")(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
@ -68,8 +68,8 @@ def test_gelu_fast(
) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
    x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
    activation_ops.gelu_fast(out, x)
    ref_out = get_activation("gelu_fast")(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
@ -7,16 +7,21 @@ from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from vllm import attention_ops
from vllm.utils import get_max_shared_memory_bytes

MAX_SEQ_LEN = 8192
NUM_BLOCKS = 128  # Arbitrary values for testing
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 40000  # Arbitrary values for testing
PARTITION_SIZE = 512

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [1, 3, 7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
SEEDS = [0]

@ -92,6 +97,7 @@ def ref_single_query_cached_kv_attention(
        output[i].copy_(out, non_blocking=True)


@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@ -99,9 +105,9 @@ def ref_single_query_cached_kv_attention(
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_single_query_cached_kv_attention(
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
@ -135,6 +141,7 @@ def test_single_query_cached_kv_attention(
                               device="cuda")

    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    context_lens[-1] = MAX_SEQ_LEN
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

@ -157,19 +164,54 @@ def test_single_query_cached_kv_attention(

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    attention_ops.single_query_cached_kv_attention(
        output,
        query,
        key_cache,
        value_cache,
        head_mapping,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
        alibi_slopes,
    )
    if version == "v1":
        attention_ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
    elif version == "v2":
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        num_seqs, num_heads, head_size = output.shape
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        attention_ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
    else:
        assert False, f"Unknown version: {version}"

    # Run the reference implementation.
    ref_output = torch.empty_like(query)
@ -242,7 +284,11 @@ def test_multi_query_kv_attention(
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
@ -6,13 +6,13 @@ import torch
from vllm import cache_ops

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
NUM_LAYERS = [5]  # Arbitrary values for testing
NUM_TOKENS = [83]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
NUM_BLOCKS = [1024]  # Arbitrary values for testing
NUM_MAPPINGS = [32, 256]  # Arbitrary values for testing
NUM_BLOCKS = [1024, 36000]  # Arbitrary values for testing
NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]


@ -69,9 +69,9 @@ def test_copy_blocks(
    for src, dsts in block_mapping.items():
        for dst in dsts:
            for cloned_key_cache in cloned_key_caches:
                cloned_key_cache[dst] = cloned_key_cache[src]
                cloned_key_cache[dst].copy_(cloned_key_cache[src])
            for cloned_value_cache in cloned_value_caches:
                cloned_value_cache[dst] = cloned_value_cache[src]
                cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
@ -106,14 +106,14 @@ def test_reshape_and_cache(
    # Create a random slot mapping.
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device="cuda")

    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device='cuda')
                      device="cuda")
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches.
@ -132,7 +132,7 @@ def test_reshape_and_cache(

    # Run the reference implementation.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
    block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indicies = block_indicies.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets = block_offsets.cpu().tolist()
@ -133,13 +133,14 @@ def test_rotary_embedding(
                        device="cuda")

    # Create the rotary embedding.
    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
    inv_freq = 1.0 / (base**(
        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
    t = torch.arange(max_position).float()
    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    cos = freqs.cos()
    sin = freqs.sin()
    cos_sin_cache = torch.cat((cos, sin), dim=-1)
    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")

    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
    out_query = query.clone()
@ -6,14 +6,16 @@ import pytest

MODELS = [
    "facebook/opt-125m",
    "meta-llama/Llama-2-7b-hf",
    "mistralai/Mistral-7B-v0.1",
    "tiiuae/falcon-7b",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-70m",
    "bigscience/bloom-560m",
    "mosaicml/mpt-7b",
    "tiiuae/falcon-7b",
    "meta-llama/Llama-2-7b-hf",
    "microsoft/phi-1_5",
]

55  tests/samplers/test_logprobs.py  Normal file
@ -0,0 +1,55 @@
import pytest
import torch

from vllm import SamplingParams

MODELS = ["facebook/opt-125m"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_get_prompt_logprobs(
    hf_runner,
    vllm_runner,
    model,
    dtype,
    example_prompts,
):
    max_tokens = 5
    hf_model = hf_runner(model, dtype=dtype)
    hf_logprobs = hf_model.generate_greedy_logprobs(
        example_prompts,
        max_tokens=max_tokens,
    )
    del hf_model

    vllm_model = vllm_runner(model, dtype=dtype)
    vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
                                          logprobs=5,
                                          prompt_logprobs=5,
                                          temperature=0.0)
    vllm_results = vllm_model.model.generate(
        example_prompts, sampling_params=vllm_sampling_params)

    # Test whether logprobs are included in the results.
    for result in vllm_results:
        assert result.prompt_logprobs is not None
        assert result.outputs[0].logprobs is not None

    # Test whether prompt logprobs are consistent with HF
    for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
        # Check prompt logprobs
        vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
        for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
            for token_id, logprob in vllm_prompt_logprob_dict.items():
                torch.testing.assert_close(logprob,
                                           hf_logprob[0][i][token_id].item(),
                                           atol=1e-2,
                                           rtol=1e-2)
        vllm_sample_logprobs = vllm_result.outputs[0].logprobs
        for i, vllm_sample_logprob_dict in enumerate(vllm_sample_logprobs):
            for token_id, logprob in vllm_sample_logprob_dict.items():
                torch.testing.assert_close(logprob,
                                           hf_logprob[i][-1][token_id].item(),
                                           atol=1e-2,
                                           rtol=1e-2)
219  tests/samplers/test_sampler.py  Normal file
@ -0,0 +1,219 @@
# pylint: disable=protected-access
import random
from typing import Tuple
from unittest.mock import patch

import pytest
import torch

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.worker import Worker


class MockLogitsSampler(Sampler):

    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size)
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
                   lambda x, y: x):
            with patch("vllm.model_executor.layers.sampler._get_logits",
                       lambda *args, **kwargs: self.fake_logits):
                return super().forward(*args, **kwargs)


def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device="cuda",
                              dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             device=input_tensor.device,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(32000, fake_logits)
    worker = Worker(None, None, None)
    worker.block_size = 16
    return input_tensor, fake_logits, sampler, worker


RANDOM_SEEDS = list(range(128))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_greedy(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0, ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1.0,
                    n=random.randint(1, 10),
                ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=0,
                    best_of=2,
                    use_beam_search=True,
                ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler(embedding=None,
            hidden_states=input_tensor,
            input_metadata=input_metadata)
    # no assertion here as I am not sure how to determine whether
    # the outputs are expected - in other words, this just tests
    # whether there are no exceptions in the sampler
    # when handling an all-beam search case.


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_mixed(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    expected_tokens = []
    for i in range(batch_size):
        n = 1
        sampling_type = random.randint(0, 2)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
        elif sampling_type == 1:
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
        else:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        for idx in range(n):
            fake_logits[i, i + idx] = 1e2
            expected_tokens.append(i + idx)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    for i, sequence_output in enumerate(sampler_output):
        if seq_group_metadata_list[i].sampling_params.use_beam_search:
            continue
        for nth_output in sequence_output.samples:
            assert nth_output.output_token in expected_tokens


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_logits_processors(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, worker = _prepare_test(batch_size)

    # This sample logits processor gives infinite score to the i-th token,
    # where i is the length of the input sequence.
    # We therefore expect the output token sequence to be [0, 1, 2, ...]
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0,
                                               logits_processors=[pick_ith]),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    for i, sequence_output in enumerate(sampler_output):
        for idx, nth_output in enumerate(sequence_output.samples):
            assert nth_output.output_token == idx
27  tests/test_regression.py  Normal file
@ -0,0 +1,27 @@
"""Containing tests that check for regressions in vLLM's behavior.

It should include tests that are reported by users and making sure they
will never happen again.

"""
from vllm import LLM, SamplingParams


def test_duplicated_ignored_sequence_group():
    """https://github.com/vllm-project/vllm/issues/1655"""

    sampling_params = SamplingParams(temperature=0.01,
                                     top_p=0.1,
                                     max_tokens=256)
    llm = LLM(model="facebook/opt-125m",
              max_num_batched_tokens=4096,
              tensor_parallel_size=1)
    prompts = ["This is a short prompt", "This is a very long prompt " * 1000]
    outputs = llm.generate(prompts, sampling_params=sampling_params)

    assert len(prompts) == len(outputs)


if __name__ == "__main__":
    import pytest
    pytest.main([__file__])
44  tests/worker/test_worker.py  Normal file
@ -0,0 +1,44 @@
# pylint: disable=protected-access
import random
import torch

from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.worker import Worker


def test_worker_prepare_inputs_for_prompt():
    worker = Worker(None, None, None)
    worker.block_size = 16
    batch_size = random.randint(1, 256)
    prompt_lens = []
    seq_group_metadata_list = []
    for i in range(batch_size):
        # make sure all tokens fit into one block
        prompt_len = i % (worker.block_size - 1) + 1
        prompt_lens.append(prompt_len)
        seq_data = list(range(prompt_len))
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData(seq_data)},
                sampling_params=SamplingParams(temperature=0),
                block_tables={0: [1]},
            ))
    expected_selected_token_indices = []
    selected_token_start_idx = 0
    max_seq_len = max(prompt_lens)
    for prompt_len in prompt_lens:
        expected_selected_token_indices.append(selected_token_start_idx +
                                               prompt_len - 1)
        selected_token_start_idx += max_seq_len
    input_tokens, input_positions, input_metadata = worker._prepare_inputs(
        seq_group_metadata_list)
    assert input_tokens.shape == input_positions.shape == (batch_size,
                                                           max_seq_len)
    torch.testing.assert_close(input_tokens, input_positions)
    actual = input_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)
@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams

__version__ = "0.1.5"
__version__ = "0.2.2"

__all__ = [
    "LLM",
257  vllm/config.py
@ -1,4 +1,5 @@
from typing import Optional
from typing import Optional, Union
import os

import torch
from transformers import PretrainedConfig
@ -38,6 +39,16 @@ class ModelConfig:
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
            the default version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
    """

    def __init__(
@ -48,8 +59,12 @@ class ModelConfig:
        trust_remote_code: bool,
        download_dir: Optional[str],
        load_format: str,
        dtype: str,
        dtype: Union[str, torch.dtype],
        seed: int,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
@ -58,11 +73,28 @@ class ModelConfig:
        self.download_dir = download_dir
        self.load_format = load_format
        self.seed = seed
        self.revision = revision
        self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization

        self.hf_config = get_config(model, trust_remote_code)
        if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            from modelscope.hub.snapshot_download import snapshot_download  # pylint: disable=C
            model_path = snapshot_download(model_id=model,
                                           cache_dir=download_dir,
                                           revision=revision)
            self.model = model_path
            self.download_dir = model_path
            self.tokenizer = model_path

        self.hf_config = get_config(self.model, trust_remote_code, revision)
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
        self.max_model_len = _get_and_verify_max_len(self.hf_config,
                                                     max_model_len)
        self._verify_load_format()
        self._verify_tokenizer_mode()
        self._verify_quantization()

    def _verify_load_format(self) -> None:
        load_format = self.load_format.lower()
@ -82,6 +114,33 @@ class ModelConfig:
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def _verify_quantization(self) -> None:
        supported_quantization = ["awq", "squeezellm"]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        hf_quant_config = getattr(self.hf_config, "quantization_config", None)
        if hf_quant_config is not None:
            hf_quant_method = str(hf_quant_config["quant_method"]).lower()
            if self.quantization is None:
                self.quantization = hf_quant_method
            elif self.quantization != hf_quant_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({hf_quant_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization}).")

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}.")
            logger.warning(f"{self.quantization} quantization is not fully "
                           "optimized yet. The speed can be slower than "
                           "non-quantized models.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
@ -109,48 +168,49 @@ class ModelConfig:
        # FIXME(woosuk): This may not be true for all models.
        return self.hf_config.hidden_size // self.hf_config.num_attention_heads

    def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # Note: for falcon, when new_decoder_architecture is True, the
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type == "falcon"
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                   "multi_query", False):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1
        # For Falcon:
        if getattr(self.hf_config, "n_head_kv", None) is not None:
            return (self.hf_config.n_head_kv //
                    parallel_config.tensor_parallel_size)
        # For LLaMA-2:
        if getattr(self.hf_config, "num_key_value_heads", None) is not None:
            return (self.hf_config.num_key_value_heads //
                    parallel_config.tensor_parallel_size)
        total_num_attention_heads = self.hf_config.num_attention_heads
        return total_num_attention_heads // parallel_config.tensor_parallel_size

    def get_max_model_len(self) -> int:
        max_model_len = float("inf")
        possible_keys = [
            # OPT
            "max_position_embeddings",
            # GPT-2
            "n_positions",
            # MPT
            "max_seq_len",
            # Others
            "max_sequence_length",
            "max_seq_length",
            "seq_len",
        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for key in possible_keys:
            max_len_key = getattr(self.hf_config, key, None)
            if max_len_key is not None:
                max_model_len = min(max_model_len, max_len_key)
        return max_model_len
        for attr in attributes:
            num_kv_heads = getattr(self.hf_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_config.num_attention_heads

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        total_num_hidden_layers = self.hf_config.num_hidden_layers
@ -172,10 +232,12 @@ class CacheConfig:
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * _GB
        self.sliding_window = sliding_window
        self._verify_args()

        # Will be set after profiling.
@ -249,13 +311,41 @ class SchedulerConfig:
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
        max_paddings: Maximum number of paddings to be added to a batch.
    """

    def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
                 max_model_len: int) -> None:
        self.max_num_batched_tokens = max_num_batched_tokens
    def __init__(
        self,
        max_num_batched_tokens: Optional[int],
        max_num_seqs: int,
        max_model_len: int,
        max_paddings: int,
    ) -> None:
        if max_num_batched_tokens is not None:
            self.max_num_batched_tokens = max_num_batched_tokens
        else:
            # If max_model_len is too short, use 2048 as the default value for
            # higher throughput.
            self.max_num_batched_tokens = max(max_model_len, 2048)
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.max_paddings = max_paddings
        self._verify_args()

    def _verify_args(self) -> None:
        if self.max_num_batched_tokens < self.max_model_len:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")
        if self.max_num_batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")


_STR_DTYPE_TO_TORCH_DTYPE = {
@ -269,7 +359,7 @@ _STR_DTYPE_TO_TORCH_DTYPE = {

def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: str,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
@ -277,17 +367,23 @@ def _get_and_verify_dtype(
    if config_dtype is None:
        config_dtype = torch.float32

    dtype = dtype.lower()
    if dtype == "auto":
        if config_dtype == torch.float32:
            # Following the common practice, we use float16 for float32 models.
            torch_dtype = torch.float16
    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                # Following the common practice, we use float16 for float32
                # models.
                torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            torch_dtype = config_dtype
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
            raise ValueError(f"Unknown dtype: {dtype}")
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
        raise ValueError(f"Unknown dtype: {dtype}")

    # Verify the dtype.
    if torch_dtype != config_dtype:
@ -301,13 +397,62 @@ def _get_and_verify_dtype(
        # Casting between float16 and bfloat16 is allowed with a warning.
        logger.warning(f"Casting {config_dtype} to {torch_dtype}.")

    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                f"{compute_capability[0]}.{compute_capability[1]}.")
    return torch_dtype


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
) -> int:
    """Get and verify the model's maximum length."""
    derived_max_model_len = float("inf")
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
        # ChatGLM2
        "seq_length",
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
    for key in possible_keys:
        max_len_key = getattr(hf_config, key, None)
        if max_len_key is not None:
            derived_max_model_len = min(derived_max_model_len, max_len_key)
    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            f"{possible_keys}. Assuming the model's maximum length is "
            f"{default_max_len}.")
        derived_max_model_len = default_max_len

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        assert "factor" in rope_scaling
        scaling_factor = rope_scaling["factor"]
        if rope_scaling["type"] == "yarn":
            derived_max_model_len = rope_scaling[
                "original_max_position_embeddings"]
        derived_max_model_len *= scaling_factor

    if max_model_len is None:
        max_model_len = derived_max_model_len
    elif max_model_len > derived_max_model_len:
        raise ValueError(
            f"User-specified max_model_len ({max_model_len}) is greater than "
            f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
            " in model's config.json). This may lead to incorrect model "
            "outputs or CUDA errors. Make sure the value is correct and "
            "within the model context size.")
    return int(max_model_len)
@ -63,10 +63,18 @@ class BlockSpaceManager:
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        watermark: float = 0.01,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.block_size = block_size
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks

        self.block_sliding_window = None
        if sliding_window is not None:
            assert sliding_window % block_size == 0, (sliding_window,
                                                      block_size)
            self.block_sliding_window = sliding_window // block_size

        self.watermark = watermark
        assert watermark >= 0.0

@ -83,6 +91,9 @@ class BlockSpaceManager:
        # the same prompt. This may not be true for preempted sequences.
        seq = seq_group.get_seqs()[0]
        num_required_blocks = len(seq.logical_token_blocks)
        if self.block_sliding_window is not None:
            num_required_blocks = min(num_required_blocks,
                                      self.block_sliding_window)
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        # Use watermark to avoid frequent cache eviction.
        return (num_free_gpu_blocks - num_required_blocks >=
@ -95,8 +106,12 @@ class BlockSpaceManager:

        # Allocate new physical token blocks that will store the prompt tokens.
        block_table: BlockTable = []
        for _ in range(len(seq.logical_token_blocks)):
            block = self.gpu_allocator.allocate()
        for logical_idx in range(len(seq.logical_token_blocks)):
            if (self.block_sliding_window is not None
                    and logical_idx >= self.block_sliding_window):
                block = block_table[logical_idx % self.block_sliding_window]
            else:
                block = self.gpu_allocator.allocate()
            # Set the reference counts of the token blocks.
            block.ref_count = seq_group.num_seqs()
            block_table.append(block)
@ -118,11 +133,17 @@ class BlockSpaceManager:
        block_table = self.block_tables[seq.seq_id]

        if len(block_table) < len(logical_blocks):
            # The sequence has a new logical block.
            # Allocate a new physical block.
            block = self.gpu_allocator.allocate()
            block_table.append(block)
            return None
            if (self.block_sliding_window
                    and len(block_table) >= self.block_sliding_window):
                # re-use a block
                block_table.append(block_table[len(block_table) %
                                               self.block_sliding_window])
            else:
                # The sequence has a new logical block.
                # Allocate a new physical block.
                block = self.gpu_allocator.allocate()
                block_table.append(block)
            return None

        # We want to append the token to the last physical block.
        last_block = block_table[-1]
@ -154,9 +175,7 @@ class BlockSpaceManager:
        for seq in seq_group.get_seqs():
            if seq.is_finished():
                continue
            block_table = self.block_tables[seq.seq_id]
            for block in block_table:
                blocks.add(block)
            blocks.update(self.block_tables[seq.seq_id])
        return list(blocks)

    def can_swap_in(self, seq_group: SequenceGroup) -> bool:
@ -224,7 +243,7 @@ class BlockSpaceManager:
        return block_number_mapping

    def _free_block_table(self, block_table: BlockTable) -> None:
        for block in block_table:
        for block in set(block_table):
            if block.device == Device.GPU:
                self.gpu_allocator.free(block)
            else:
@ -73,7 +73,7 @@ class Scheduler:
            block_size=self.cache_config.block_size,
            num_gpu_blocks=self.cache_config.num_gpu_blocks,
            num_cpu_blocks=self.cache_config.num_cpu_blocks,
        )
            sliding_window=self.cache_config.sliding_window)

        # TODO(zhuohan): Use deque instead of list for better performance.
        # Sequence groups in the WAITING state.
@ -121,7 +121,7 @@ class Scheduler:
        blocks_to_copy: Dict[int, List[int]] = {}

        # Fix the current time.
        now = time.time()
        now = time.monotonic()

        # Join waiting sequences if possible.
        if not self.swapped:
@ -131,7 +131,8 @@ class Scheduler:
            # requests in the generation phase.
            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                for seq_group in self.running)
            num_batched_tokens = 0
            seq_lens: List[int] = []

            # Optimization: We do not sort the waiting queue since the preempted
            # sequence groups are added to the front and the new sequence groups
            # are added to the back.
@ -157,7 +158,9 @@ class Scheduler:
                    break

                # If the number of batched tokens exceeds the limit, stop.
                if (num_batched_tokens + num_prompt_tokens >
                new_seq_lens = seq_lens + [num_prompt_tokens]
                num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
                if (num_batched_tokens >
                        self.scheduler_config.max_num_batched_tokens):
                    break

@ -168,18 +171,22 @@ class Scheduler:
                        self.scheduler_config.max_num_seqs):
                    break

                num_paddings = num_batched_tokens - sum(new_seq_lens)
                if num_paddings > self.scheduler_config.max_paddings:
                    break
                seq_lens = new_seq_lens

                seq_group = self.waiting.pop(0)
                self._allocate(seq_group)
                self.running.append(seq_group)
                num_batched_tokens += num_prompt_tokens
                num_curr_seqs += num_new_seqs
                scheduled.append(seq_group)

            if scheduled:
            if scheduled or ignored_seq_groups:
                scheduler_outputs = SchedulerOutputs(
                    scheduled_seq_groups=scheduled,
                    prompt_run=True,
                    num_batched_tokens=num_batched_tokens,
                    num_batched_tokens=len(seq_lens) * max(seq_lens),
                    blocks_to_swap_in=blocks_to_swap_in,
                    blocks_to_swap_out=blocks_to_swap_out,
                    blocks_to_copy=blocks_to_copy,
@ -268,7 +275,7 @@ class Scheduler:
        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for seq_group in scheduler_outputs.scheduled_seq_groups:
            seq_data: Dict[int, List[SequenceData]] = {}
            seq_data: Dict[int, SequenceData] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
@ -18,20 +18,24 @@ class EngineArgs:
|
||||
load_format: str = 'auto'
|
||||
dtype: str = 'auto'
|
||||
seed: int = 0
|
||||
max_model_len: Optional[int] = None
|
||||
worker_use_ray: bool = False
|
||||
pipeline_parallel_size: int = 1
|
||||
tensor_parallel_size: int = 1
|
||||
block_size: int = 16
|
||||
swap_space: int = 4 # GiB
|
||||
gpu_memory_utilization: float = 0.90
|
||||
max_num_batched_tokens: int = 2560
|
||||
max_num_batched_tokens: Optional[int] = None
|
||||
max_num_seqs: int = 256
|
||||
max_paddings: int = 256
|
||||
disable_log_stats: bool = False
|
||||
revision: Optional[str] = None
|
||||
tokenizer_revision: Optional[str] = None
|
||||
quantization: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer is None:
|
||||
self.tokenizer = self.model
|
||||
self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
@ -48,6 +52,20 @@ class EngineArgs:
|
||||
type=str,
|
||||
default=EngineArgs.tokenizer,
|
||||
help='name or path of the huggingface tokenizer to use')
|
||||
parser.add_argument(
|
||||
'--revision',
|
||||
type=str,
|
||||
default=None,
|
||||
help='the specific model version to use. It can be a branch '
|
||||
'name, a tag name, or a commit id. If unspecified, will use '
|
||||
'the default version.')
|
||||
parser.add_argument(
|
||||
'--tokenizer-revision',
|
||||
type=str,
|
||||
default=None,
|
||||
help='the specific tokenizer version to use. It can be a branch '
|
||||
'name, a tag name, or a commit id. If unspecified, will use '
|
||||
'the default version.')
|
||||
parser.add_argument('--tokenizer-mode',
|
||||
type=str,
|
||||
default=EngineArgs.tokenizer_mode,
|
||||
@ -79,16 +97,22 @@ class EngineArgs:
|
||||
'a numpy cache to speed up the loading. '
|
||||
'"dummy" will initialize the weights with random values, '
|
||||
'which is mainly for profiling.')
|
||||
# TODO(woosuk): Support FP32.
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
default=EngineArgs.dtype,
|
||||
choices=['auto', 'half', 'bfloat16', 'float'],
|
||||
choices=[
|
||||
'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
|
||||
],
|
||||
help='data type for model weights and activations. '
|
||||
'The "auto" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.')
|
||||
parser.add_argument('--max-model-len',
|
||||
type=int,
|
||||
default=None,
|
||||
help='model context length. If unspecified, '
|
||||
'will be automatically derived from the model.')
|
||||
# Parallel arguments
|
||||
parser.add_argument('--worker-use-ray',
|
||||
action='store_true',
|
||||
@ -133,9 +157,20 @@ class EngineArgs:
|
||||
type=int,
|
||||
default=EngineArgs.max_num_seqs,
|
||||
help='maximum number of sequences per iteration')
|
||||
parser.add_argument('--max-paddings',
|
||||
type=int,
|
||||
default=EngineArgs.max_paddings,
|
||||
help='maximum number of paddings in a batch')
|
||||
parser.add_argument('--disable-log-stats',
|
||||
action='store_true',
|
||||
help='disable logging statistics')
|
||||
# Quantization settings.
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
type=str,
|
||||
choices=['awq', 'squeezellm', None],
|
||||
default=None,
|
||||
help='Method used to quantize the weights')
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
@ -149,20 +184,22 @@ class EngineArgs:
|
||||
def create_engine_configs(
|
||||
self,
|
||||
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
|
||||
# Initialize the configs.
|
||||
model_config = ModelConfig(self.model, self.tokenizer,
|
||||
self.tokenizer_mode, self.trust_remote_code,
|
||||
self.download_dir, self.load_format,
|
||||
self.dtype, self.seed)
|
||||
cache_config = CacheConfig(self.block_size,
|
||||
self.gpu_memory_utilization,
|
||||
self.swap_space)
|
||||
self.dtype, self.seed, self.revision,
|
||||
self.tokenizer_revision, self.max_model_len,
|
||||
self.quantization)
|
||||
cache_config = CacheConfig(
|
||||
self.block_size, self.gpu_memory_utilization, self.swap_space,
|
||||
getattr(model_config.hf_config, 'sliding_window', None))
|
||||
parallel_config = ParallelConfig(self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size,
|
||||
self.worker_use_ray)
|
||||
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
|
||||
self.max_num_seqs,
|
||||
model_config.get_max_model_len())
|
||||
model_config.max_model_len,
|
||||
self.max_paddings)
|
||||
return model_config, cache_config, parallel_config, scheduler_config
|
||||
|
||||
|
||||
@ -171,6 +208,7 @@ class AsyncEngineArgs(EngineArgs):
"""Arguments for asynchronous vLLM engine."""
engine_use_ray: bool = False
disable_log_requests: bool = False
max_log_len: Optional[int] = None

@staticmethod
def add_cli_args(
@ -183,4 +221,10 @@ class AsyncEngineArgs(EngineArgs):
parser.add_argument('--disable-log-requests',
                    action='store_true',
                    help='disable logging requests')
parser.add_argument('--max-log-len',
                    type=int,
                    default=None,
                    help='max number of prompt characters or prompt '
                    'ID numbers being printed in log. '
                    'Default: unlimited.')
return parser
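A minimal sketch of what --max-log-len means in practice, assuming a hypothetical truncate_for_log helper; the real truncation lives in AsyncLLMEngine.add_request further down in this diff.

# Hypothetical helper illustrating --max-log-len: shorten the prompt text and
# the prompt token ids before they are written to the log. Mirrors the logic
# added to AsyncLLMEngine.add_request later in this diff.
from typing import List, Optional, Tuple

def truncate_for_log(
    prompt: Optional[str],
    prompt_token_ids: Optional[List[int]],
    max_log_len: Optional[int],
) -> Tuple[Optional[str], Optional[List[int]]]:
    if max_log_len is None:
        return prompt, prompt_token_ids  # Default: unlimited.
    shortened_prompt = prompt[:max_log_len] if prompt is not None else None
    shortened_ids = (prompt_token_ids[:max_log_len]
                     if prompt_token_ids is not None else None)
    return shortened_prompt, shortened_ids

print(truncate_for_log("a very long prompt", [1, 2, 3, 4, 5], max_log_len=4))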
@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
|
||||
Union)
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
@ -78,14 +79,24 @@ class RequestTracker:
|
||||
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
|
||||
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
|
||||
dict]] = asyncio.Queue()
|
||||
self.new_requests_event = None
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self._request_streams
|
||||
|
||||
def propagate_exception(self, exc: Exception) -> None:
|
||||
"""Propagate an exception to all request streams."""
|
||||
for stream in self._request_streams.values():
|
||||
stream.put(exc)
|
||||
def init_event(self):
|
||||
self.new_requests_event = asyncio.Event()
|
||||
|
||||
def propagate_exception(self,
|
||||
exc: Exception,
|
||||
request_id: Optional[str] = None) -> None:
|
||||
"""Propagate an exception to request streams
|
||||
(all if request_id is None)."""
|
||||
if request_id is not None:
|
||||
self._request_streams[request_id].put(exc)
|
||||
else:
|
||||
for stream in self._request_streams.values():
|
||||
stream.put(exc)
|
||||
|
||||
def process_request_output(self,
|
||||
request_output: RequestOutput,
|
||||
@ -112,6 +123,9 @@ class RequestTracker:
|
||||
"request_id": request_id,
|
||||
**engine_add_request_kwargs
|
||||
}))
|
||||
|
||||
self.new_requests_event.set()
|
||||
|
||||
return stream
|
||||
|
||||
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
|
||||
@ -128,10 +142,10 @@ class RequestTracker:
|
||||
|
||||
self._request_streams[request_id].finish()
|
||||
|
||||
def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
|
||||
def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
|
||||
"""Get the new requests and finished requests to be
|
||||
sent to the engine."""
|
||||
new_requests: List[dict] = []
|
||||
new_requests: List[Dict] = []
|
||||
finished_requests: Set[str] = set()
|
||||
|
||||
while not self._finished_requests.empty():
|
||||
@ -148,8 +162,13 @@ class RequestTracker:
|
||||
self._request_streams[stream.request_id] = stream
|
||||
new_requests.append(new_request)
|
||||
|
||||
self.new_requests_event.clear()
|
||||
|
||||
return new_requests, finished_requests
|
||||
|
||||
async def wait_for_new_requests(self):
|
||||
await self.new_requests_event.wait()
|
||||
|
||||
|
||||
class _AsyncLLMEngine(LLMEngine):
|
||||
"""Extension of LLMEngine to add async methods."""
|
||||
@ -164,10 +183,9 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
and updates the scheduler with the model outputs. Finally, it decodes
|
||||
the sequences and returns the newly generated results.
|
||||
"""
|
||||
(seq_group_metadata_list, scheduler_outputs,
|
||||
early_return) = self._schedule()
|
||||
if early_return is not None:
|
||||
return early_return
|
||||
seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
|
||||
if scheduler_outputs.is_empty():
|
||||
return ignored
|
||||
|
||||
# Execute the model.
|
||||
output = await self._run_workers_async(
|
||||
@ -178,7 +196,7 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||
)
|
||||
|
||||
return self._process_model_outputs(output, scheduler_outputs)
|
||||
return self._process_model_outputs(output, scheduler_outputs) + ignored
|
||||
|
||||
async def _run_workers_async(
    self,
@ -188,18 +206,17 @@ class _AsyncLLMEngine(LLMEngine):
    **kwargs,
) -> Any:
    """Runs the given method on all workers."""
    all_outputs = []
    coros = []
    for worker in self.workers:
        if self.parallel_config.worker_use_ray:
            executor = partial(worker.execute_method.remote, method)
            coros.append(
                worker.execute_method.remote(method, *args, **kwargs))
        else:
            executor = getattr(worker, method)
            coros.append(asyncio.get_event_loop().run_in_executor(
                None, partial(executor, *args, **kwargs)))

        output = executor(*args, **kwargs)
        all_outputs.append(output)

    if self.parallel_config.worker_use_ray:
        all_outputs = await asyncio.gather(*all_outputs)
    all_outputs = await asyncio.gather(*coros)

    if get_all_outputs:
        return all_outputs
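The rewritten helper collects one awaitable per worker and gathers them all at once. The stand-alone sketch below shows the same fan-out pattern with a thread-pool executor and a dummy worker class; it does not use Ray and is not vLLM's actual code.

# Sketch of the fan-out pattern used by _run_workers_async: build one awaitable
# per worker and gather them all. DummyWorker stands in for vllm.worker.Worker.
import asyncio
from functools import partial

class DummyWorker:
    def __init__(self, rank: int) -> None:
        self.rank = rank

    def execute_model(self, step: int) -> str:
        return f"worker {self.rank} ran step {step}"

async def run_workers_async(workers, method: str, *args, **kwargs):
    loop = asyncio.get_event_loop()
    coros = []
    for worker in workers:
        executor = getattr(worker, method)
        # Blocking worker calls run in the default thread pool.
        coros.append(loop.run_in_executor(None, partial(executor, *args, **kwargs)))
    return await asyncio.gather(*coros)

async def main() -> None:
    workers = [DummyWorker(rank) for rank in range(2)]
    print(await run_workers_async(workers, "execute_model", step=7))

asyncio.run(main())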
@ -230,6 +247,8 @@ class AsyncLLMEngine:
|
||||
async frontend will be executed in a separate process as the
|
||||
model workers.
|
||||
log_requests: Whether to log the requests.
|
||||
start_engine_loop: If True, the background task to run the engine
|
||||
will be automatically started in the generate call.
|
||||
*args, *kwargs: Arguments for LLMEngine.
|
||||
"""
|
||||
|
||||
@ -240,17 +259,22 @@ class AsyncLLMEngine:
|
||||
engine_use_ray: bool,
|
||||
*args,
|
||||
log_requests: bool = True,
|
||||
start_engine_loop: bool = False,
|
||||
max_log_len: Optional[int] = None,
|
||||
start_engine_loop: bool = True,
|
||||
**kwargs) -> None:
|
||||
self.worker_use_ray = worker_use_ray
|
||||
self.engine_use_ray = engine_use_ray
|
||||
self.log_requests = log_requests
|
||||
self.max_log_len = max_log_len
|
||||
self.engine = self._init_engine(*args, **kwargs)
|
||||
|
||||
self.request_tracker: RequestTracker = RequestTracker()
|
||||
self.background_loop = None
|
||||
if start_engine_loop:
|
||||
self.start_background_loop()
|
||||
# We need to keep a reference to unshielded
|
||||
# task as well to prevent it from being garbage
|
||||
# collected
|
||||
self._background_loop_unshielded = None
|
||||
self.start_engine_loop = start_engine_loop
|
||||
self._request_tracker = RequestTracker()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
@ -261,11 +285,14 @@ class AsyncLLMEngine:
|
||||
"""Start the background loop."""
|
||||
if self.is_running:
|
||||
raise RuntimeError("Background loop is already running.")
|
||||
self.background_loop = asyncio.get_event_loop().create_task(
|
||||
self.run_engine_loop())
|
||||
self.background_loop.add_done_callback(
|
||||
self._request_tracker.init_event()
|
||||
|
||||
self._background_loop_unshielded = asyncio.get_event_loop(
|
||||
).create_task(self.run_engine_loop())
|
||||
self._background_loop_unshielded.add_done_callback(
|
||||
partial(_raise_exception_on_finish,
|
||||
request_tracker=self.request_tracker))
|
||||
request_tracker=self._request_tracker))
|
||||
self.background_loop = asyncio.shield(self._background_loop_unshielded)
|
||||
|
||||
def _init_engine(self, *args,
|
||||
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
|
||||
@ -277,11 +304,13 @@ class AsyncLLMEngine:
|
||||
engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
|
||||
return engine_class(*args, **kwargs)
|
||||
|
||||
async def engine_step(self):
|
||||
"""Kick the engine to process the waiting requests."""
|
||||
async def engine_step(self) -> bool:
|
||||
"""Kick the engine to process the waiting requests.
|
||||
|
||||
Returns True if there are in-progress requests."""
|
||||
|
||||
new_requests, finished_requests = (
|
||||
self.request_tracker.get_new_and_finished_requests())
|
||||
self._request_tracker.get_new_and_finished_requests())
|
||||
|
||||
for new_request in new_requests:
|
||||
# Add the request into the vLLM engine's waiting queue.
|
||||
@ -301,9 +330,11 @@ class AsyncLLMEngine:
|
||||
|
||||
# Put the outputs into the corresponding streams.
|
||||
for request_output in request_outputs:
|
||||
self.request_tracker.process_request_output(
|
||||
self._request_tracker.process_request_output(
|
||||
request_output, verbose=self.log_requests)
|
||||
|
||||
return len(request_outputs) > 0
|
||||
|
||||
async def _engine_abort(self, request_ids: Iterable[str]):
|
||||
if self.engine_use_ray:
|
||||
await self.engine.abort_request.remote(request_ids)
|
||||
@ -311,8 +342,12 @@ class AsyncLLMEngine:
|
||||
self.engine.abort_request(request_ids)
|
||||
|
||||
async def run_engine_loop(self):
    # Initialize the RequestTracker here so it uses the right event loop.
    has_requests_in_progress = False
    while True:
        await self.engine_step()
        if not has_requests_in_progress:
            await self._request_tracker.wait_for_new_requests()
        has_requests_in_progress = await self.engine_step()
        await asyncio.sleep(0)

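The rewritten loop above only steps the engine once at least one request is pending, by sleeping on an asyncio.Event. Below is a self-contained sketch of that pattern with a dummy in-memory queue standing in for the engine (names and the bounded iteration count are illustrative only).

# Stand-alone sketch of the event-driven loop pattern used above: sleep on an
# asyncio.Event while idle, then keep stepping while work is in progress.
# The "engine" here is a plain list, not vLLM.
import asyncio
from typing import List

async def demo() -> None:
    new_requests = asyncio.Event()
    pending: List[str] = []

    async def engine_step() -> bool:
        if pending:
            print("processed", pending.pop(0))
        return bool(pending)

    async def engine_loop() -> None:
        has_requests_in_progress = False
        # Two iterations are enough for the two demo requests;
        # the real loop runs forever.
        for _ in range(2):
            if not has_requests_in_progress:
                await new_requests.wait()
                new_requests.clear()
            has_requests_in_progress = await engine_step()
            await asyncio.sleep(0)

    loop_task = asyncio.create_task(engine_loop())
    pending.extend(["req-0", "req-1"])
    new_requests.set()  # wake the loop
    await loop_task

asyncio.run(demo())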
async def add_request(
|
||||
@ -324,19 +359,30 @@ class AsyncLLMEngine:
|
||||
arrival_time: Optional[float] = None,
|
||||
) -> AsyncStream:
|
||||
if self.log_requests:
|
||||
shortened_prompt = prompt
|
||||
shortened_token_ids = prompt_token_ids
|
||||
if self.max_log_len is not None:
|
||||
if shortened_prompt is not None:
|
||||
shortened_prompt = shortened_prompt[:self.max_log_len]
|
||||
if shortened_token_ids is not None:
|
||||
shortened_token_ids = shortened_token_ids[:self.
|
||||
max_log_len]
|
||||
logger.info(f"Received request {request_id}: "
|
||||
f"prompt: {prompt!r}, "
|
||||
f"prompt: {shortened_prompt!r}, "
|
||||
f"sampling params: {sampling_params}, "
|
||||
f"prompt token ids: {prompt_token_ids}.")
|
||||
f"prompt token ids: {shortened_token_ids}.")
|
||||
|
||||
if not self.is_running:
|
||||
raise AsyncEngineDeadError(
|
||||
"Background loop is not running. If it was running, "
|
||||
"inspect the output to find the stacktrace of the "
|
||||
"error that caused the background loop to stop "
|
||||
"(AsyncEngineDeadError).")
|
||||
if self.start_engine_loop:
|
||||
self.start_background_loop()
|
||||
else:
|
||||
raise AsyncEngineDeadError(
|
||||
"Background loop is not running. If it was running, "
|
||||
"inspect the output to find the stacktrace of the "
|
||||
"error that caused the background loop to stop "
|
||||
"(AsyncEngineDeadError).")
|
||||
|
||||
stream = self.request_tracker.add_request(
|
||||
stream = self._request_tracker.add_request(
|
||||
request_id,
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params,
|
||||
@ -370,7 +416,8 @@ class AsyncLLMEngine:
|
||||
request.
|
||||
"""
|
||||
# Preprocess the request.
|
||||
arrival_time = time.time()
|
||||
# This should not be used for logging, as it is monotonic time.
|
||||
arrival_time = time.monotonic()
|
||||
|
||||
try:
|
||||
stream = await self.add_request(request_id,
|
||||
@ -381,8 +428,9 @@ class AsyncLLMEngine:
|
||||
|
||||
async for request_output in stream:
|
||||
yield request_output
|
||||
except Exception as e:
|
||||
# If there is an exception, abort the request.
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
# If there is an exception or coroutine is cancelled, abort the
|
||||
# request.
|
||||
self._abort(request_id)
|
||||
raise e
|
||||
|
||||
@ -413,8 +461,8 @@ class AsyncLLMEngine:
|
||||
Args:
|
||||
request_id: The unique id of the request.
|
||||
"""
|
||||
self.request_tracker.abort_request(request_id,
|
||||
verbose=self.log_requests)
|
||||
self._request_tracker.abort_request(request_id,
|
||||
verbose=self.log_requests)
|
||||
|
||||
async def get_model_config(self) -> ModelConfig:
|
||||
"""Get the model configuration of the vLLM engine."""
|
||||
@ -426,7 +474,7 @@ class AsyncLLMEngine:
|
||||
@classmethod
|
||||
def from_engine_args(cls,
|
||||
engine_args: AsyncEngineArgs,
|
||||
start_engine_loop: bool = False) -> "AsyncLLMEngine":
|
||||
start_engine_loop: bool = True) -> "AsyncLLMEngine":
|
||||
"""Creates an async LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
engine_configs = engine_args.create_engine_configs()
|
||||
@ -435,12 +483,13 @@ class AsyncLLMEngine:
|
||||
distributed_init_method, placement_group = initialize_cluster(
|
||||
parallel_config, engine_args.engine_use_ray)
|
||||
# Create the async LLM engine.
|
||||
engine = cls(engine_args.worker_use_ray,
|
||||
engine = cls(parallel_config.worker_use_ray,
|
||||
engine_args.engine_use_ray,
|
||||
*engine_configs,
|
||||
distributed_init_method,
|
||||
placement_group,
|
||||
log_requests=not engine_args.disable_log_requests,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
max_log_len=engine_args.max_log_len,
|
||||
start_engine_loop=start_engine_loop)
|
||||
return engine
|
||||
|
@ -12,8 +12,8 @@ from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceOutputs,
|
||||
SequenceStatus)
|
||||
SequenceGroupMetadata, SequenceGroupOutputs,
|
||||
SequenceOutputs, SequenceStatus)
|
||||
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
||||
get_tokenizer)
|
||||
from vllm.utils import Counter
|
||||
@ -54,8 +54,8 @@ class LLMEngine:
|
||||
scheduler_config: The configuration related to the request scheduler.
|
||||
distributed_init_method: The initialization method for distributed
|
||||
execution. See `torch.distributed.init_process_group` for details.
|
||||
stage_devices: The list of devices for each stage. Each stage is a list
|
||||
of (rank, node_resource, device) tuples.
|
||||
placement_group: Ray placement group for distributed execution.
|
||||
Required for distributed execution.
|
||||
log_stats: Whether to log statistics.
|
||||
"""
|
||||
|
||||
@ -74,16 +74,22 @@ class LLMEngine:
|
||||
f"model={model_config.model!r}, "
|
||||
f"tokenizer={model_config.tokenizer!r}, "
|
||||
f"tokenizer_mode={model_config.tokenizer_mode}, "
|
||||
f"revision={model_config.revision}, "
|
||||
f"tokenizer_revision={model_config.tokenizer_revision}, "
|
||||
f"trust_remote_code={model_config.trust_remote_code}, "
|
||||
f"dtype={model_config.dtype}, "
|
||||
f"max_seq_len={model_config.max_model_len}, "
|
||||
f"download_dir={model_config.download_dir!r}, "
|
||||
f"load_format={model_config.load_format}, "
|
||||
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
||||
f"quantization={model_config.quantization}, "
|
||||
f"seed={model_config.seed})")
|
||||
# TODO(woosuk): Print more configs in debug mode.
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
assert self.cache_config.sliding_window == getattr(
|
||||
self.model_config.hf_config, "sliding_window", None)
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.log_stats = log_stats
|
||||
@ -92,7 +98,9 @@ class LLMEngine:
|
||||
self.tokenizer = get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
tokenizer_mode=model_config.tokenizer_mode,
|
||||
trust_remote_code=model_config.trust_remote_code)
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
tokenizer_revision=model_config.tokenizer_revision,
|
||||
revision=model_config.revision)
|
||||
self.seq_counter = Counter()
|
||||
|
||||
# Create the parallel GPU workers.
|
||||
@ -153,7 +161,7 @@ class LLMEngine:
|
||||
placement_group=placement_group,
|
||||
placement_group_capture_child_tasks=True),
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorker).remote()
|
||||
)(RayWorker).remote(self.model_config.trust_remote_code)
|
||||
self.workers.append(worker)
|
||||
|
||||
# Initialize torch distributed process group for the workers.
|
||||
@ -248,10 +256,10 @@ class LLMEngine:
|
||||
prompt_token_ids: The token IDs of the prompt. If None, we
|
||||
use the tokenizer to convert the prompts to token IDs.
|
||||
arrival_time: The arrival time of the request. If None, we use
|
||||
the current time.
|
||||
the current monotonic time.
|
||||
"""
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
arrival_time = time.monotonic()
|
||||
if prompt_token_ids is None:
|
||||
assert prompt is not None
|
||||
prompt_token_ids = self.tokenizer.encode(prompt)
|
||||
@ -291,14 +299,12 @@ class LLMEngine:
|
||||
def _schedule(
|
||||
self
|
||||
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
|
||||
Optional[List[RequestOutput]]]:
|
||||
List[RequestOutput]]:
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||
if scheduler_outputs.is_empty():
|
||||
return seq_group_metadata_list, scheduler_outputs, [
|
||||
RequestOutput.from_seq_group(seq_group)
|
||||
for seq_group in scheduler_outputs.ignored_seq_groups
|
||||
]
|
||||
return seq_group_metadata_list, scheduler_outputs, None
|
||||
return seq_group_metadata_list, scheduler_outputs, [
|
||||
RequestOutput.from_seq_group(seq_group)
|
||||
for seq_group in scheduler_outputs.ignored_seq_groups
|
||||
]
|
||||
|
||||
def _check_beam_search_early_stopping(
|
||||
self,
|
||||
@ -344,9 +350,15 @@ class LLMEngine:
|
||||
eos_token_id=self.tokenizer.eos_token_id))
|
||||
return current_worst_score >= highest_attainable_score
|
||||
|
||||
def _process_sequence_group_samples(
|
||||
self, seq_group: SequenceGroup,
|
||||
samples: List[SequenceOutputs]) -> None:
|
||||
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
|
||||
outputs: SequenceGroupOutputs) -> None:
|
||||
# Process prompt logprobs
|
||||
prompt_logprobs = outputs.prompt_logprobs
|
||||
if prompt_logprobs is not None:
|
||||
seq_group.prompt_logprobs = prompt_logprobs
|
||||
|
||||
# Process samples
|
||||
samples = outputs.samples
|
||||
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
||||
existing_finished_seqs = seq_group.get_finished_seqs()
|
||||
parent_child_dict = {
|
||||
@ -386,7 +398,7 @@ class LLMEngine:
|
||||
child_seqs.append((parent, parent))
|
||||
|
||||
for seq, _ in child_seqs:
|
||||
self._decode_sequence(seq)
|
||||
self._decode_sequence(seq, seq_group.sampling_params)
|
||||
self._check_stop(seq, seq_group.sampling_params)
|
||||
|
||||
# Non-beam search case
|
||||
@ -514,8 +526,8 @@ class LLMEngine:
|
||||
scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
|
||||
# Update the scheduled sequence groups with the model outputs.
|
||||
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
|
||||
for seq_group, samples in zip(scheduled_seq_groups, output):
|
||||
self._process_sequence_group_samples(seq_group, samples)
|
||||
for seq_group, outputs in zip(scheduled_seq_groups, output):
|
||||
self._process_sequence_group_outputs(seq_group, outputs)
|
||||
|
||||
# Free the finished sequence groups.
|
||||
self.scheduler.free_finished_seq_groups()
|
||||
@ -542,10 +554,9 @@ class LLMEngine:
|
||||
and updates the scheduler with the model outputs. Finally, it decodes
|
||||
the sequences and returns the newly generated results.
|
||||
"""
|
||||
(seq_group_metadata_list, scheduler_outputs,
|
||||
early_return) = self._schedule()
|
||||
if early_return is not None:
|
||||
return early_return
|
||||
seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
|
||||
if scheduler_outputs.is_empty():
|
||||
return ignored
|
||||
|
||||
# Execute the model.
|
||||
output = self._run_workers(
|
||||
@ -563,7 +574,7 @@ class LLMEngine:
|
||||
prompt_run: bool,
|
||||
num_batched_tokens: int,
|
||||
) -> None:
|
||||
now = time.time()
|
||||
now = time.monotonic()
|
||||
# Log the number of batched input tokens.
|
||||
if prompt_run:
|
||||
self.num_prompt_tokens.append((now, num_batched_tokens))
|
||||
@ -621,17 +632,25 @@ class LLMEngine:
|
||||
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
|
||||
self.last_logging_time = now
|
||||
|
||||
def _decode_sequence(self, seq: Sequence) -> None:
def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
    """Decodes the new token for a sequence."""
    new_token, new_output_text = detokenize_incrementally(
        self.tokenizer,
        seq.output_tokens,
        seq.get_last_token_id(),
        skip_special_tokens=True,
    )
    if new_token is not None:
        seq.output_tokens.append(new_token)
        seq.output_text = new_output_text
    (new_tokens, new_output_text, prefix_offset,
     read_offset) = detokenize_incrementally(
         self.tokenizer,
         all_input_ids=seq.get_token_ids(),
         prev_tokens=seq.tokens,
         prefix_offset=seq.prefix_offset,
         read_offset=seq.read_offset,
         skip_special_tokens=prms.skip_special_tokens,
         spaces_between_special_tokens=prms.spaces_between_special_tokens,
     )
    if seq.tokens is None:
        seq.tokens = new_tokens
    else:
        seq.tokens.extend(new_tokens)
    seq.prefix_offset = prefix_offset
    seq.read_offset = read_offset
    seq.output_text += new_output_text

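The new _decode_sequence keeps prefix_offset/read_offset into the running token list so text is emitted exactly once per step. The toy sketch below illustrates only that offset bookkeeping, using string "tokens"; it is a simplification, not vLLM's detokenize_incrementally, which also handles subword and byte-level pieces.

# Toy illustration of incremental detokenization offsets: emit only the text
# that has not been emitted yet, and advance the offsets for the next step.
from typing import List, Tuple

def toy_detokenize_incrementally(
    all_tokens: List[str],
    prefix_offset: int,
    read_offset: int,
) -> Tuple[str, int, int]:
    prefix_text = "".join(all_tokens[prefix_offset:read_offset])
    full_text = "".join(all_tokens[prefix_offset:])
    new_text = full_text[len(prefix_text):]  # only the part not yet emitted
    # New prefix_offset = old read_offset; new read_offset = all tokens seen.
    return new_text, read_offset, len(all_tokens)

tokens: List[str] = []
prefix_offset = read_offset = 0
output_text = ""
for piece in ["Hel", "lo", ",", " wor", "ld"]:
    tokens.append(piece)
    new_text, prefix_offset, read_offset = toy_detokenize_incrementally(
        tokens, prefix_offset, read_offset)
    output_text += new_text
assert output_text == "Hello, world"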
def _check_stop(self, seq: Sequence,
|
||||
sampling_params: SamplingParams) -> None:
|
||||
@ -643,6 +662,9 @@ class LLMEngine:
|
||||
seq.output_text = seq.output_text[:-len(stop_str)]
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
if seq.get_last_token_id() in sampling_params.stop_token_ids:
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
|
||||
# Check if the sequence has reached max_model_len.
|
||||
if seq.get_len() > self.scheduler_config.max_model_len:
|
||||
|
@ -2,6 +2,9 @@ import socket
from typing import Optional, Tuple, TYPE_CHECKING

from vllm.config import ParallelConfig
from vllm.logger import init_logger

logger = init_logger(__name__)

try:
    import ray
@ -11,7 +14,11 @@ try:
        """Ray wrapper for vllm.worker.Worker, allowing Worker to be
        lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""

        def __init__(self) -> None:
        def __init__(self, init_cached_hf_modules=False) -> None:
            if init_cached_hf_modules:
                # pylint: disable=import-outside-toplevel
                from transformers.dynamic_module_utils import init_hf_modules
                init_hf_modules()
            self.worker = None

        def init_worker(self, worker_init_fn):
@ -24,7 +31,10 @@ try:
            executor = getattr(self, method)
            return executor(*args, **kwargs)

except ImportError:
except ImportError as e:
    logger.warning(f"Failed to import Ray with {e!r}. "
                   "For distributed inference, please install Ray with "
                   "`pip install ray pandas pyarrow`.")
    ray = None
    TorchDistributedWorker = None
    RayWorker = None # pylint: disable=invalid-name
@ -53,11 +63,10 @@ def initialize_cluster(
|
||||
the default Ray cluster address.
|
||||
|
||||
Returns:
|
||||
A tuple of (`distributed_init_method`, `all_stage_devices`). The
|
||||
A tuple of (`distributed_init_method`, `placement_group`). The
|
||||
`distributed_init_method` is the address for initializing the
|
||||
distributed backend. `all_stage_devices` includes device IDs for
|
||||
each worker in each pipeline stage. Each device ID is a tuple of
|
||||
(rank, node resource, device id).
|
||||
distributed backend. `placement_group` includes the specification
|
||||
of the resources for each distributed worker.
|
||||
"""
|
||||
if parallel_config.worker_use_ray or engine_use_ray:
|
||||
if ray is None:
|
||||
|
@ -2,7 +2,7 @@ import argparse
import json
from typing import AsyncGenerator

from fastapi import BackgroundTasks, FastAPI, Request
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn

@ -17,6 +17,12 @@ app = FastAPI()
engine = None


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


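For reference, a small client-side check of the new /health route and a non-streaming /generate call; it assumes the server above is already running on localhost:8000 and uses only the standard library.

# Hypothetical client snippet: assumes `python -m vllm.entrypoints.api_server`
# is already listening on localhost:8000.
import json
import urllib.request

# The new /health endpoint returns an empty 200 response.
with urllib.request.urlopen("http://localhost:8000/health") as resp:
    assert resp.status == 200

# /generate takes a prompt plus SamplingParams fields in the JSON body.
payload = json.dumps({"prompt": "Hello, my name is", "max_tokens": 16}).encode()
req = urllib.request.Request("http://localhost:8000/generate",
                             data=payload,
                             headers={"Content-Type": "application/json"})
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["text"])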
@app.post("/generate")
|
||||
async def generate(request: Request) -> Response:
|
||||
"""Generate completion for the request.
|
||||
@ -32,9 +38,6 @@ async def generate(request: Request) -> Response:
|
||||
sampling_params = SamplingParams(**request_dict)
|
||||
request_id = random_uuid()
|
||||
|
||||
if not engine.is_running:
|
||||
engine.start_background_loop()
|
||||
|
||||
results_generator = engine.generate(prompt, sampling_params, request_id)
|
||||
|
||||
# Streaming case
|
||||
@ -47,14 +50,8 @@ async def generate(request: Request) -> Response:
|
||||
ret = {"text": text_outputs}
|
||||
yield (json.dumps(ret) + "\0").encode("utf-8")
|
||||
|
||||
async def abort_request() -> None:
|
||||
await engine.abort(request_id)
|
||||
|
||||
if stream:
|
||||
background_tasks = BackgroundTasks()
|
||||
# Abort the request if the client disconnects.
|
||||
background_tasks.add_task(abort_request)
|
||||
return StreamingResponse(stream_results(), background=background_tasks)
|
||||
return StreamingResponse(stream_results())
|
||||
|
||||
# Non-streaming case
|
||||
final_output = None
|
||||
@ -74,14 +71,13 @@ async def generate(request: Request) -> Response:
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--host", type=str, default=None)
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args,
|
||||
start_engine_loop=False)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
|
||||
uvicorn.run(app,
|
||||
host=args.host,
|
||||
|
@ -37,7 +37,24 @@ class LLM:
|
||||
the `torch_dtype` attribute specified in the model config file.
|
||||
However, if the `torch_dtype` in the config is `float32`, we will
|
||||
use `float16` instead.
|
||||
quantization: The method used to quantize the model weights. Currently,
|
||||
we support "awq". If None, we assume the model weights are not
|
||||
quantized and use `dtype` to determine the data type of the weights.
|
||||
revision: The specific model version to use. It can be a branch name,
|
||||
a tag name, or a commit id.
|
||||
tokenizer_revision: The specific tokenizer version to use. It can be a
|
||||
branch name, a tag name, or a commit id.
|
||||
seed: The seed to initialize the random number generator for sampling.
|
||||
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
|
||||
reserve for the model weights, activations, and KV cache. Higher
|
||||
values will increase the KV cache size and thus improve the model's
|
||||
throughput. However, if the value is too high, it may cause out-of-
|
||||
memory (OOM) errors.
|
||||
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
|
||||
This can be used for temporarily storing the states of the requests
|
||||
when their `best_of` sampling parameters are larger than 1. If all
|
||||
requests will have `best_of=1`, you can safely set this to 0.
|
||||
Otherwise, too small values may cause out-of-memory (OOM) errors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -48,7 +65,12 @@ class LLM:
|
||||
trust_remote_code: bool = False,
|
||||
tensor_parallel_size: int = 1,
|
||||
dtype: str = "auto",
|
||||
quantization: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
tokenizer_revision: Optional[str] = None,
|
||||
seed: int = 0,
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
swap_space: int = 4,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if "disable_log_stats" not in kwargs:
|
||||
@ -60,7 +82,12 @@ class LLM:
|
||||
trust_remote_code=trust_remote_code,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
dtype=dtype,
|
||||
quantization=quantization,
|
||||
revision=revision,
|
||||
tokenizer_revision=tokenizer_revision,
|
||||
seed=seed,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
swap_space=swap_space,
|
||||
**kwargs,
|
||||
)
|
||||
self.llm_engine = LLMEngine.from_engine_args(engine_args)
|
||||
|
@ -10,10 +10,10 @@ from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import fastapi
|
||||
import uvicorn
|
||||
from fastapi import BackgroundTasks, Request
|
||||
from fastapi import Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse, Response
|
||||
from packaging import version
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
@ -130,6 +130,8 @@ async def check_length(
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
token_num = len(input_ids)
|
||||
|
||||
if request.max_tokens is None:
|
||||
request.max_tokens = max_model_len - token_num
|
||||
if token_num + request.max_tokens > max_model_len:
|
||||
return input_ids, create_error_response(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
@ -143,6 +145,12 @@ async def check_length(
|
||||
return input_ids, None
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> Response:
|
||||
"""Health check."""
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def show_available_models():
|
||||
"""Show available models. Right now we only have one model."""
|
||||
@ -192,14 +200,11 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
"""
|
||||
logger.info(f"Received chat completion request: {request}")
|
||||
|
||||
if not engine.is_running:
|
||||
engine.start_background_loop()
|
||||
|
||||
error_check_ret = await check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
if request.logit_bias is not None:
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||
# TODO: support logit_bias in vLLM engine.
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"logit_bias is not currently supported")
|
||||
@ -211,8 +216,9 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
|
||||
model_name = request.model
|
||||
request_id = f"cmpl-{random_uuid()}"
|
||||
created_time = int(time.time())
|
||||
created_time = int(time.monotonic())
|
||||
try:
|
||||
spaces_between_special_tokens = request.spaces_between_special_tokens
|
||||
sampling_params = SamplingParams(
|
||||
n=request.n,
|
||||
presence_penalty=request.presence_penalty,
|
||||
@ -220,11 +226,14 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
stop=request.stop,
|
||||
stop_token_ids=request.stop_token_ids,
|
||||
max_tokens=request.max_tokens,
|
||||
best_of=request.best_of,
|
||||
top_k=request.top_k,
|
||||
ignore_eos=request.ignore_eos,
|
||||
use_beam_search=request.use_beam_search,
|
||||
skip_special_tokens=request.skip_special_tokens,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||
)
|
||||
except ValueError as e:
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||
@ -232,13 +241,11 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
result_generator = engine.generate(prompt, sampling_params, request_id,
|
||||
token_ids)
|
||||
|
||||
async def abort_request() -> None:
|
||||
await engine.abort(request_id)
|
||||
|
||||
def create_stream_response_json(
|
||||
index: int,
|
||||
text: str,
|
||||
finish_reason: Optional[str] = None,
|
||||
usage: Optional[UsageInfo] = None,
|
||||
) -> str:
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
@ -251,7 +258,10 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
model=model_name,
|
||||
choices=[choice_data],
|
||||
)
|
||||
response_json = response.json(ensure_ascii=False)
|
||||
if usage is not None:
|
||||
response.usage = usage
|
||||
# exclude unset to leave details out of each sse
|
||||
response_json = response.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
return response_json
|
||||
|
||||
@ -277,36 +287,40 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
i = output.index
|
||||
delta_text = output.text[len(previous_texts[i]):]
|
||||
previous_texts[i] = output.text
|
||||
previous_num_tokens[i] = len(output.token_ids)
|
||||
completion_tokens = len(output.token_ids)
|
||||
previous_num_tokens[i] = completion_tokens
|
||||
response_json = create_stream_response_json(
|
||||
index=i,
|
||||
text=delta_text,
|
||||
)
|
||||
yield f"data: {response_json}\n\n"
|
||||
if output.finish_reason is not None:
|
||||
prompt_tokens = len(res.prompt_token_ids)
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
response_json = create_stream_response_json(
|
||||
index=i,
|
||||
text="",
|
||||
finish_reason=output.finish_reason,
|
||||
usage=final_usage,
|
||||
)
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
# Streaming response
|
||||
if request.stream:
|
||||
background_tasks = BackgroundTasks()
|
||||
# Abort the request if the client disconnects.
|
||||
background_tasks.add_task(abort_request)
|
||||
return StreamingResponse(completion_stream_generator(),
|
||||
media_type="text/event-stream",
|
||||
background=background_tasks)
|
||||
media_type="text/event-stream")
|
||||
|
||||
# Non-streaming response
|
||||
final_res: RequestOutput = None
|
||||
async for res in result_generator:
|
||||
if await raw_request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await abort_request()
|
||||
await engine.abort(request_id)
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"Client disconnected")
|
||||
final_res = res
|
||||
@ -367,9 +381,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
"""
|
||||
logger.info(f"Received completion request: {request}")
|
||||
|
||||
if not engine.is_running:
|
||||
engine.start_background_loop()
|
||||
|
||||
error_check_ret = await check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
@ -385,7 +396,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"suffix is not currently supported")
|
||||
|
||||
if request.logit_bias is not None:
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||
# TODO: support logit_bias in vLLM engine.
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"logit_bias is not currently supported")
|
||||
@ -420,8 +431,9 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
created_time = int(time.time())
|
||||
created_time = int(time.monotonic())
|
||||
try:
|
||||
spaces_between_special_tokens = request.spaces_between_special_tokens
|
||||
sampling_params = SamplingParams(
|
||||
n=request.n,
|
||||
best_of=request.best_of,
|
||||
@ -431,10 +443,13 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
stop=request.stop,
|
||||
stop_token_ids=request.stop_token_ids,
|
||||
ignore_eos=request.ignore_eos,
|
||||
max_tokens=request.max_tokens,
|
||||
logprobs=request.logprobs,
|
||||
use_beam_search=request.use_beam_search,
|
||||
skip_special_tokens=request.skip_special_tokens,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||
)
|
||||
except ValueError as e:
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||
@ -454,14 +469,12 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
and (request.best_of is None or request.n == request.best_of)
|
||||
and not request.use_beam_search)
|
||||
|
||||
async def abort_request() -> None:
|
||||
await engine.abort(request_id)
|
||||
|
||||
def create_stream_response_json(
|
||||
index: int,
|
||||
text: str,
|
||||
logprobs: Optional[LogProbs] = None,
|
||||
finish_reason: Optional[str] = None,
|
||||
usage: Optional[UsageInfo] = None,
|
||||
) -> str:
|
||||
choice_data = CompletionResponseStreamChoice(
|
||||
index=index,
|
||||
@ -475,7 +488,9 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
model=model_name,
|
||||
choices=[choice_data],
|
||||
)
|
||||
response_json = response.json(ensure_ascii=False)
|
||||
if usage is not None:
|
||||
response.usage = usage
|
||||
response_json = response.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
return response_json
|
||||
|
||||
@ -505,30 +520,34 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
if output.finish_reason is not None:
|
||||
logprobs = (LogProbs()
|
||||
if request.logprobs is not None else None)
|
||||
prompt_tokens = len(res.prompt_token_ids)
|
||||
completion_tokens = len(output.token_ids)
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
response_json = create_stream_response_json(
|
||||
index=i,
|
||||
text="",
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
usage=final_usage,
|
||||
)
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
# Streaming response
|
||||
if stream:
|
||||
background_tasks = BackgroundTasks()
|
||||
# Abort the request if the client disconnects.
|
||||
background_tasks.add_task(abort_request)
|
||||
return StreamingResponse(completion_stream_generator(),
|
||||
media_type="text/event-stream",
|
||||
background=background_tasks)
|
||||
media_type="text/event-stream")
|
||||
|
||||
# Non-streaming response
|
||||
final_res: RequestOutput = None
|
||||
async for res in result_generator:
|
||||
if await raw_request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await abort_request()
|
||||
await engine.abort(request_id)
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"Client disconnected")
|
||||
final_res = res
|
||||
@ -581,10 +600,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="vLLM OpenAI-Compatible RESTful API server.")
|
||||
parser.add_argument("--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="host name")
|
||||
parser.add_argument("--host", type=str, default=None, help="host name")
|
||||
parser.add_argument("--port", type=int, default=8000, help="port number")
|
||||
parser.add_argument("--allow-credentials",
|
||||
action="store_true",
|
||||
@ -627,15 +643,15 @@ if __name__ == "__main__":
|
||||
served_model = args.model
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args,
|
||||
start_engine_loop=False)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
engine_model_config = asyncio.run(engine.get_model_config())
|
||||
max_model_len = engine_model_config.get_max_model_len()
|
||||
max_model_len = engine_model_config.max_model_len
|
||||
|
||||
# A separate tokenizer to map token IDs to strings.
|
||||
tokenizer = get_tokenizer(engine_args.tokenizer,
|
||||
tokenizer_mode=engine_args.tokenizer_mode,
|
||||
trust_remote_code=engine_args.trust_remote_code)
|
||||
tokenizer = get_tokenizer(
|
||||
engine_model_config.tokenizer,
|
||||
tokenizer_mode=engine_model_config.tokenizer_mode,
|
||||
trust_remote_code=engine_model_config.trust_remote_code)
|
||||
|
||||
uvicorn.run(app,
|
||||
host=args.host,
|
||||
|
@ -58,7 +58,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
temperature: Optional[float] = 0.7
|
||||
top_p: Optional[float] = 1.0
|
||||
n: Optional[int] = 1
|
||||
max_tokens: Optional[int] = 16
|
||||
max_tokens: Optional[int] = None
|
||||
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
||||
stream: Optional[bool] = False
|
||||
presence_penalty: Optional[float] = 0.0
|
||||
@ -70,6 +70,9 @@ class ChatCompletionRequest(BaseModel):
|
||||
top_k: Optional[int] = -1
|
||||
ignore_eos: Optional[bool] = False
|
||||
use_beam_search: Optional[bool] = False
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
skip_special_tokens: Optional[bool] = True
|
||||
spaces_between_special_tokens: Optional[bool] = True
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
@ -94,6 +97,9 @@ class CompletionRequest(BaseModel):
|
||||
top_k: Optional[int] = -1
|
||||
ignore_eos: Optional[bool] = False
|
||||
use_beam_search: Optional[bool] = False
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
skip_special_tokens: Optional[bool] = True
|
||||
spaces_between_special_tokens: Optional[bool] = True
|
||||
|
||||
|
||||
class LogProbs(BaseModel):
|
||||
@ -133,6 +139,7 @@ class CompletionStreamResponse(BaseModel):
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[CompletionResponseStreamChoice]
|
||||
usage: Optional[UsageInfo]
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
@ -172,3 +179,5 @@ class ChatCompletionStreamResponse(BaseModel):
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseStreamChoice]
|
||||
usage: Optional[UsageInfo] = Field(
|
||||
default=None, description="data about request and response")
|
||||
|
@ -48,4 +48,9 @@ _setup_logger()


def init_logger(name: str):
    return logging.getLogger(name)
    # Use the same settings as above for root logger
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(_default_handler)
    logger.propagate = False
    return logger
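Call sites are unchanged by this; a typical module-level usage, assuming vLLM is installed, looks like:

# Module-level logger that now inherits the shared handler, DEBUG level, and
# propagate=False configured in vllm/logger.py.
from vllm.logger import init_logger

logger = init_logger(__name__)
logger.info("engine initialized")  # formatted by vLLM's default handler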
|
@ -1,9 +1,9 @@
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from xformers.ops import AttentionBias
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
|
||||
@ -29,6 +29,9 @@ class InputMetadata:
|
||||
context_lens: torch.Tensor,
|
||||
max_context_len: int,
|
||||
block_tables: torch.Tensor,
|
||||
selected_token_indices: torch.Tensor,
|
||||
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> None:
|
||||
self.seq_groups = seq_groups
|
||||
self.seq_data = seq_data
|
||||
@ -37,31 +40,52 @@ class InputMetadata:
|
||||
self.context_lens = context_lens
|
||||
self.max_context_len = max_context_len
|
||||
self.block_tables = block_tables
|
||||
self.selected_token_indices = selected_token_indices
|
||||
self.categorized_sample_indices = categorized_sample_indices
|
||||
|
||||
self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
self.to_cache = None
if sliding_window is not None:
    # We need to keep the positions of sliding windows within
    # the key / value tables, this is helpful to know which
    # elements we need to cache.
    to_cache, start_idx = [], 0
    for prompt_len in self.prompt_lens:
        to_cache.extend(
            range(
                start_idx + max(0, prompt_len - sliding_window),
                start_idx + prompt_len,
            ))
        start_idx += self.max_prompt_len
    to_cache.extend(range(start_idx, slot_mapping.shape[0]))
    self.to_cache = torch.tensor(to_cache,
                                 dtype=torch.int32,
                                 device=self.slot_mapping.device)

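A tiny worked example of the to_cache computation above, using plain lists instead of tensors: with a sliding window of 3, only the last three positions of each padded prompt plus all trailing generation slots are cached (the concrete sizes below are made up for illustration).

# Worked example of the sliding-window caching indices computed above.
# Two prompts of lengths 4 and 2, padded to max_prompt_len = 4,
# sliding_window = 3, and 2 generation slots at the end.
prompt_lens = [4, 2]
max_prompt_len = max(prompt_lens)
sliding_window = 3
num_slots = len(prompt_lens) * max_prompt_len + 2  # prompt area + generation tokens

to_cache, start_idx = [], 0
for prompt_len in prompt_lens:
    # Keep only the last `sliding_window` positions of each prompt.
    to_cache.extend(
        range(start_idx + max(0, prompt_len - sliding_window),
              start_idx + prompt_len))
    start_idx += max_prompt_len
to_cache.extend(range(start_idx, num_slots))  # generation tokens are always cached

print(to_cache)  # [1, 2, 3, 4, 5, 8, 9]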
self.num_prompts = len(prompt_lens)
|
||||
self.num_prompt_tokens = sum(prompt_lens)
|
||||
self.num_prompt_tokens = self.num_prompts * self.max_prompt_len
|
||||
self.num_generation_tokens = context_lens.shape[0]
|
||||
self.num_valid_tokens = slot_mapping.shape[0]
|
||||
if block_tables.numel() > 0:
|
||||
self.max_num_blocks_per_seq = block_tables.shape[1]
|
||||
else:
|
||||
self.max_num_blocks_per_seq = 0
|
||||
assert block_tables.shape[0] == self.num_generation_tokens
|
||||
assert context_lens.shape[0] == self.num_generation_tokens
|
||||
|
||||
# Set during the execution of the first attention op.
|
||||
self.attn_bias: List[AttentionBias] = []
|
||||
self.attn_bias: Optional[AttentionBias] = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# Print only useful metadata.
|
||||
return (f'InputMetadata('
|
||||
f'num_valid_tokens={self.num_valid_tokens}, '
|
||||
f'num_prompt_tokens={self.num_prompt_tokens}, '
|
||||
f'num_prompts={self.num_prompts}, '
|
||||
f'prompt_lens={self.prompt_lens}, '
|
||||
f'num_generation_tokens={self.num_generation_tokens}, '
|
||||
f'context_lens={self.context_lens}, '
|
||||
f'max_context_len={self.max_context_len}), '
|
||||
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
|
||||
f'block_tables={self.block_tables}), '
|
||||
f'slot_mapping={self.slot_mapping}')
|
||||
return (
|
||||
f'InputMetadata('
|
||||
f'num_prompt_tokens={self.num_prompt_tokens}, '
|
||||
f'num_prompts={self.num_prompts}, '
|
||||
f'prompt_lens={self.prompt_lens}, '
|
||||
f'num_generation_tokens={self.num_generation_tokens}, '
|
||||
f'context_lens={self.context_lens}, '
|
||||
f'max_context_len={self.max_context_len}), '
|
||||
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
|
||||
f'block_tables={self.block_tables}, '
|
||||
f'selected_token_indices={self.selected_token_indices}, '
|
||||
f'categorized_sample_indices={self.categorized_sample_indices}, '
|
||||
f'slot_mapping={self.slot_mapping})')
|
||||
|
@ -1,24 +1,27 @@
|
||||
"""Custom activation functions."""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm import activation_ops
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
|
||||
|
||||
class SiluAndMul(nn.Module):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d)
        return: (num_tokens, d)
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        num_tokens = x.shape[0]
        d = x.shape[1] // 2
        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        activation_ops.silu_and_mul(out, x)
        return out

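An unfused PyTorch reference for the reshaped SiluAndMul, handy for sanity-checking the kernel's new [..., 2*d] -> [..., d] shape handling; this is just an equivalent torch expression, not the fused CUDA path.

# Unfused reference for SiluAndMul with the new [..., 2 * d] -> [..., d] shapes.
# vllm.activation_ops.silu_and_mul is the fused CUDA kernel; this is only a
# torch-level equivalent for checking shapes and values.
import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(2, 5, 8)  # (batch_size, seq_len, 2 * d)
assert silu_and_mul_ref(x).shape == (2, 5, 4)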
@ -26,9 +29,7 @@ class SiluAndMul(nn.Module):
|
||||
class NewGELU(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens = x.shape[0]
|
||||
d = x.shape[1]
|
||||
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
|
||||
out = torch.empty_like(x)
|
||||
activation_ops.gelu_new(out, x)
|
||||
return out
|
||||
|
||||
@ -36,13 +37,32 @@ class NewGELU(nn.Module):
|
||||
class FastGELU(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens = x.shape[0]
|
||||
d = x.shape[1]
|
||||
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
|
||||
out = torch.empty_like(x)
|
||||
activation_ops.gelu_fast(out, x)
|
||||
return out
|
||||
|
||||
|
||||
class ScaledActivation(nn.Module):
|
||||
"""An activation function with post-scale parameters.
|
||||
|
||||
This is used for some quantization methods like AWQ.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
act_module: nn.Module,
|
||||
hidden_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
):
|
||||
super().__init__()
|
||||
self.act = act_module
|
||||
self.scales = nn.Parameter(
|
||||
torch.empty(hidden_size, dtype=params_dtype, device="cuda"))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return self.act(x) / self.scales
|
||||
|
||||
|
||||
_ACTIVATION_REGISTRY = {
|
||||
"gelu": nn.GELU(),
|
||||
"gelu_fast": FastGELU(),
|
||||
@ -52,9 +72,27 @@ _ACTIVATION_REGISTRY = {
|
||||
}
|
||||
|
||||
|
||||
def get_act_fn(act_fn: str) -> nn.Module:
def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
) -> nn.Module:
    """Get an activation function by name."""
    act_fn = act_fn.lower()
    if act_fn in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[act_fn]
    raise ValueError(f"Activation function {act_fn!r} is not supported.")
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
    if quant_config is not None:
        if act_fn_name in quant_config.get_scaled_act_names():
            if intermediate_size is None:
                raise ValueError(
                    "intermediate_size must be specified for scaled "
                    "activation functions.")
            return ScaledActivation(
                act_fn,
                intermediate_size,
                params_dtype=torch.get_default_dtype(),
            )
    return act_fn
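A brief usage sketch of the extended get_act_fn signature, assuming a vLLM install; "gelu" resolves to plain torch.nn.GELU from the registry above, so no quantization config or intermediate_size is needed.

# Usage sketch for the new get_act_fn signature (plain, non-quantized path).
# "gelu" maps to torch.nn.GELU in _ACTIVATION_REGISTRY, so this runs on CPU.
import torch
from vllm.model_executor.layers.activation import get_act_fn

act = get_act_fn("gelu", quant_config=None)
hidden = torch.randn(4, 11008)
print(act(hidden).shape)  # torch.Size([4, 11008])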
|
@ -1,5 +1,5 @@
|
||||
"""Multi-head attention."""
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -9,35 +9,21 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
|
||||
|
||||
from vllm import attention_ops
|
||||
from vllm import cache_ops
|
||||
from vllm import pos_encoding_ops
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
|
||||
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
|
||||
class PagedAttention(nn.Module):
|
||||
# pylint: disable=line-too-long
|
||||
"""GPT-style multi-head PagedAttention.
|
||||
|
||||
This class takes flattened 1D query, key, and value tensors as input. The
|
||||
input 1D tensors can either contain prompt tokens or generation tokens, in
|
||||
addition to paddings.
|
||||
|
||||
If the input tensors contain prompt tokens, the layout is as follows:
|
||||
|
||||
|<---------------------- num_valid_tokens ---------------------->|
|
||||
|<--------------- num_prompt_tokens -------------->|
|
||||
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|
|
||||
|
||||
Otherwise, the layout is as follows:
|
||||
|
||||
|<------------------ num_valid_tokens ------------------->|
|
||||
|<------- num_generation_tokens (M) ------->|
|
||||
|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
|
||||
|
||||
The prompts might have different lengths, while the generation tokens always
|
||||
have length 1. The paddings are appended to make the input length a multiple
|
||||
of 8, which is desirable for Tensor Cores.
|
||||
This class takes query, key, and value tensors as input. The input tensors
|
||||
can either contain prompt tokens or generation tokens, in addition to
|
||||
paddings.
|
||||
|
||||
The class does the following:
|
||||
1. Perform multi_query_kv_attention for the prompts. This operation does
|
||||
@ -49,19 +35,21 @@ class PagedAttention(nn.Module):
|
||||
4. Perform single_query_cached_kv_attention for the generation tokens.
|
||||
This operation reads the previous key and value tensors from the KV
|
||||
cache.
|
||||
5. Output a flattened 1D tensor.
|
||||
5. Return the output tensor.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: Optional[int] = None) -> None:
|
||||
num_kv_heads: Optional[int] = None,
|
||||
sliding_window: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
@ -73,13 +61,21 @@ class PagedAttention(nn.Module):
|
||||
raise ValueError(f"head_size ({self.head_size}) is not supported. "
|
||||
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
|
||||
|
||||
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
|
||||
if input_metadata.attn_bias:
|
||||
def set_attn_bias(
|
||||
self,
|
||||
input_metadata: InputMetadata,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
del dtype # Unused.
|
||||
if input_metadata.attn_bias is not None:
|
||||
# Already set by a previous layer.
|
||||
return
|
||||
prompt_lens = input_metadata.prompt_lens
|
||||
prompt_lens = [input_metadata.max_prompt_len
|
||||
] * input_metadata.num_prompts
|
||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
|
||||
input_metadata.attn_bias.append(attn_bias)
|
||||
if self.sliding_window is not None:
|
||||
attn_bias = attn_bias.make_local_attention(self.sliding_window)
|
||||
input_metadata.attn_bias = attn_bias
|
||||
|
||||
def multi_query_kv_attention(
|
||||
self,
|
||||
@ -98,7 +94,6 @@ class PagedAttention(nn.Module):
|
||||
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
|
||||
input_metadata: metadata for paged attention.
|
||||
"""
|
||||
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
# Project the key and value tensors to the desired number of heads.
|
||||
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
|
||||
@ -111,7 +106,7 @@ class PagedAttention(nn.Module):
|
||||
query.unsqueeze(0),
|
||||
key.unsqueeze(0),
|
||||
value.unsqueeze(0),
|
||||
attn_bias=input_metadata.attn_bias[0],
|
||||
attn_bias=input_metadata.attn_bias,
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
)
|
||||
@ -119,6 +114,14 @@ class PagedAttention(nn.Module):
|
||||
output.copy_(out.squeeze(0))
|
||||
return output
|
||||
|
||||
def get_alibi_slopes(self) -> Optional[torch.Tensor]:
|
||||
"""Returns the slopes for the alibi attention bias.
|
||||
|
||||
Returns:
|
||||
slopes: shape = [num_heads]
|
||||
"""
|
||||
return None
|
||||
|
||||
def single_query_cached_kv_attention(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
@ -126,6 +129,7 @@ class PagedAttention(nn.Module):
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
) -> None:
|
||||
"""PagedAttention for the generation tokens.
|
||||
|
||||
@ -137,21 +141,67 @@ class PagedAttention(nn.Module):
|
||||
value_cache: shape = [num_blocks, num_kv_heads, head_size,
|
||||
block_size]
|
||||
input_metadata: metadata for paged attention.
|
||||
alibi_slopes: shape = [num_heads]
|
||||
"""
|
||||
block_size = value_cache.shape[3]
|
||||
attention_ops.single_query_cached_kv_attention(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
self.head_mapping,
|
||||
self.scale,
|
||||
input_metadata.block_tables,
|
||||
input_metadata.context_lens,
|
||||
block_size,
|
||||
input_metadata.max_context_len,
|
||||
None, # alibi_slopes
|
||||
)
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
max_num_partitions = (
|
||||
(input_metadata.max_context_len + _PARTITION_SIZE - 1) //
|
||||
_PARTITION_SIZE)
|
||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||
# sequences or heads is large, we use V1 since there is enough work
|
||||
# to parallelize.
|
||||
# TODO(woosuk): Tune this heuristic.
|
||||
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
|
||||
use_v1 = input_metadata.max_context_len <= 8192 and (
|
||||
max_num_partitions == 1 or num_seqs * num_heads > 512)
|
||||
if use_v1:
|
||||
# Run PagedAttention V1.
|
||||
attention_ops.paged_attention_v1(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
self.head_mapping,
|
||||
self.scale,
|
||||
input_metadata.block_tables,
|
||||
input_metadata.context_lens,
|
||||
block_size,
|
||||
input_metadata.max_context_len,
|
||||
alibi_slopes,
|
||||
)
|
||||
else:
|
||||
# Run PagedAttention V2.
|
||||
assert _PARTITION_SIZE % block_size == 0
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||
dtype=output.dtype,
|
||||
device=output.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=output.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
attention_ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
self.head_mapping,
|
||||
self.scale,
|
||||
input_metadata.block_tables,
|
||||
input_metadata.context_lens,
|
||||
block_size,
|
||||
input_metadata.max_context_len,
|
||||
alibi_slopes,
|
||||
)
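The NOTE above is the whole story behind the V1/V2 split: V1 skips the cross-partition reduction when the context fits in one partition or when `num_seqs * num_heads` already saturates the GPU, and V2 is forced past 8192 tokens to avoid shared-memory shortage. A minimal sketch of that decision rule, assuming `_PARTITION_SIZE = 512` (the constant is referenced but not shown in this hunk) and using a hypothetical helper name:

```python
_PARTITION_SIZE = 512  # assumed value of the module-level constant used above


def use_paged_attention_v1(max_context_len: int, num_seqs: int,
                           num_heads: int) -> bool:
    """Hypothetical helper mirroring the V1/V2 heuristic in the diff above."""
    max_num_partitions = (max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    # V1 avoids the reduction overhead when a single partition suffices, or when
    # num_seqs * num_heads already provides enough parallel work; V2 is forced
    # beyond 8192 context tokens to avoid shared-memory shortage.
    return max_context_len <= 8192 and (max_num_partitions == 1
                                        or num_seqs * num_heads > 512)


# One long sequence (6000 tokens, 32 heads) -> 12 partitions, 32 work items: V2.
assert use_paged_attention_v1(6000, num_seqs=1, num_heads=32) is False
# 256 decoding sequences with 32 heads -> plenty of parallel work: V1.
assert use_paged_attention_v1(6000, num_seqs=256, num_heads=32) is True
```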
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -166,12 +216,12 @@ class PagedAttention(nn.Module):
|
||||
"""PagedAttention forward pass.
|
||||
|
||||
NOTE: The query, key, and value tensors must be sliced from a qkv
|
||||
tensor of shape [num_tokens, 3 * num_heads * head_size].
|
||||
tensor of shape [batch_size, seq_len, 3 * num_heads * head_size].
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
query: shape = [batch_size, seq_len, num_heads * head_size]
|
||||
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
|
||||
block_size, x]
|
||||
value_cache: shape = [num_blocks, num_kv_heads, head_size,
|
||||
@ -180,9 +230,9 @@ class PagedAttention(nn.Module):
|
||||
cache_event: event to wait for the cache operations to finish.
|
||||
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
shape = [batch_size, seq_len, num_heads * head_size]
|
||||
"""
|
||||
|
||||
batch_size, seq_len, _ = query.shape
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
@ -196,12 +246,12 @@ class PagedAttention(nn.Module):
|
||||
if num_prompt_tokens > 0:
|
||||
# Prompt run.
|
||||
assert input_metadata.num_generation_tokens == 0
|
||||
self.set_attn_bias(input_metadata)
|
||||
self.set_attn_bias(input_metadata, dtype=query.dtype)
|
||||
self.multi_query_kv_attention(
|
||||
output[:num_prompt_tokens],
|
||||
query[:num_prompt_tokens],
|
||||
key[:num_prompt_tokens],
|
||||
value[:num_prompt_tokens],
|
||||
output,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
input_metadata,
|
||||
)
|
||||
|
||||
@ -212,16 +262,21 @@ class PagedAttention(nn.Module):
|
||||
# Reshape the keys and values and store them in the cache.
|
||||
# When key_cache and value_cache are not provided, the new key
|
||||
# and value vectors will not be cached.
|
||||
num_valid_tokens = input_metadata.num_valid_tokens
|
||||
if (num_valid_tokens > 0 and key_cache is not None
|
||||
and value_cache is not None):
|
||||
# The stride is 3 because the key and value are sliced from qkv.
|
||||
if key_cache is not None and value_cache is not None:
|
||||
key_to_cache = key
|
||||
value_to_cache = value
|
||||
slot_mapping = input_metadata.slot_mapping.view(-1)
|
||||
if input_metadata.to_cache is not None:
|
||||
key_to_cache = key_to_cache[input_metadata.to_cache]
|
||||
value_to_cache = value_to_cache[input_metadata.to_cache]
|
||||
slot_mapping = slot_mapping[input_metadata.to_cache]
|
||||
|
||||
cache_ops.reshape_and_cache(
|
||||
key[:num_valid_tokens],
|
||||
value[:num_valid_tokens],
|
||||
key_to_cache,
|
||||
value_to_cache,
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata.slot_mapping,
|
||||
slot_mapping,
|
||||
)
|
||||
|
||||
if input_metadata.num_generation_tokens > 0:
|
||||
@ -231,18 +286,18 @@ class PagedAttention(nn.Module):
|
||||
"key_cache and value_cache must be provided when "
|
||||
"generating tokens.")
|
||||
# Compute the attention op for generation tokens.
|
||||
self.single_query_cached_kv_attention(
|
||||
output[num_prompt_tokens:num_valid_tokens],
|
||||
query[num_prompt_tokens:num_valid_tokens], key_cache,
|
||||
value_cache, input_metadata)
|
||||
self.single_query_cached_kv_attention(output, query, key_cache,
|
||||
value_cache, input_metadata,
|
||||
self.get_alibi_slopes())
|
||||
|
||||
# Reshape the output tensor.
|
||||
# NOTE(woosuk): The output tensor may include paddings.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
return output.view(batch_size, seq_len,
|
||||
self.num_heads * self.head_size)
|
||||
|
||||
|
||||
class PagedAttentionWithRoPE(PagedAttention):
|
||||
"""PagedAttention with rotary embedding."""
|
||||
"""PagedAttention with rotary positional embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -254,25 +309,16 @@ class PagedAttentionWithRoPE(PagedAttention):
|
||||
base: int = 10000,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
is_neox_style: bool = True,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads)
|
||||
self.is_neox_style = is_neox_style
|
||||
|
||||
# Create the cos and sin cache.
|
||||
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
|
||||
t = torch.arange(max_position).float()
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
|
||||
# FIXME(woosuk): This assumes that we configure the default dtype when
|
||||
# initializing the model.
|
||||
# TODO(woosuk): Make it more robust.
|
||||
torch_dtype = torch.get_default_dtype()
|
||||
cache = cache.to(torch_dtype)
|
||||
# Embedding size: [max_position, rotary_dim]
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
super().__init__(num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
sliding_window=sliding_window)
|
||||
self.rotary_emb = get_rope(head_size, rotary_dim, max_position, base,
|
||||
is_neox_style, rope_scaling)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -288,10 +334,10 @@ class PagedAttentionWithRoPE(PagedAttention):
|
||||
""" PagedAttention forward pass with rotary embedding.
|
||||
|
||||
Args:
|
||||
positions: shape = [num_tokens]
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
positions: shape = [batch_size, seq_len]
|
||||
query: shape = [batch_size, seq_len, num_heads * head_size]
|
||||
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
|
||||
block_size, x]
|
||||
value_cache: shape = [num_blocks, num_kv_heads, head_size,
|
||||
@ -300,19 +346,12 @@ class PagedAttentionWithRoPE(PagedAttention):
|
||||
cache_event: event to wait for the cache operations to finish.
|
||||
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
shape = [batch_size, seq_len, num_heads * head_size]
|
||||
"""
|
||||
|
||||
# Apply rotary embedding to the query and key before passing them
|
||||
# to the attention op.
|
||||
pos_encoding_ops.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
query, key = self.rotary_emb(positions, query, key)
|
||||
return super().forward(
|
||||
query,
|
||||
key,
|
||||
@ -339,34 +378,36 @@ class PagedAttentionWithALiBi(PagedAttention):
|
||||
slopes = torch.tensor(slopes, dtype=torch.float32)
|
||||
self.register_buffer("alibi_slopes", slopes, persistent=False)
|
||||
|
||||
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
|
||||
if input_metadata.attn_bias:
|
||||
def set_attn_bias(self, input_metadata: InputMetadata,
|
||||
dtype: torch.dtype) -> None:
|
||||
if input_metadata.attn_bias is not None:
|
||||
# Already set by a previous layer.
|
||||
return
|
||||
# Generates ALiBi mask for each prompt.
|
||||
for prompt_len in input_metadata.prompt_lens:
|
||||
bias = torch.arange(prompt_len)
|
||||
# Note(zhuohan): HF uses
|
||||
# `bias = bias[None, :].repeat(prompt_len, 1)`
|
||||
# here. We find that both biases give the same results, but
|
||||
# the bias below more accurately follows the original ALiBi
|
||||
# paper.
|
||||
bias = bias[None, :] - bias[:, None]
|
||||
bias = bias.to(self.alibi_slopes.device)
|
||||
# Generates ALiBi mask based on the max prompt length.
|
||||
max_prompt_len = input_metadata.max_prompt_len
|
||||
bias = torch.arange(max_prompt_len, dtype=dtype)
|
||||
# NOTE(zhuohan): HF uses
|
||||
# `bias = bias[None, :].repeat(prompt_len, 1)`
|
||||
# here. We find that both biases give the same results, but
|
||||
# the bias below more accurately follows the original ALiBi
|
||||
# paper.
|
||||
bias = bias[None, :] - bias[:, None]
|
||||
bias = bias.to(self.alibi_slopes.device)
|
||||
|
||||
# When using custom attention bias, xformers requires the bias to
|
||||
# be sliced from a tensor whose length is a multiple of 8.
|
||||
padded_len = (prompt_len + 7) // 8 * 8
|
||||
bias = torch.empty(
|
||||
1, # batch_size
|
||||
self.num_heads,
|
||||
prompt_len,
|
||||
padded_len,
|
||||
device=self.alibi_slopes.device,
|
||||
)[:, :, :, :prompt_len].copy_(bias)
|
||||
bias.mul_(self.alibi_slopes[:, None, None])
|
||||
attn_bias = LowerTriangularMaskWithTensorBias(bias)
|
||||
input_metadata.attn_bias.append(attn_bias)
|
||||
# When using custom attention bias, xformers requires the bias to
|
||||
# be sliced from a tensor whose length is a multiple of 8.
|
||||
padded_len = (max_prompt_len + 7) // 8 * 8
|
||||
bias = torch.empty(
|
||||
input_metadata.num_prompts,
|
||||
self.num_heads,
|
||||
max_prompt_len,
|
||||
padded_len,
|
||||
device=self.alibi_slopes.device,
|
||||
dtype=dtype,
|
||||
)[:, :, :, :max_prompt_len].copy_(bias)
|
||||
bias.mul_(self.alibi_slopes[:, None, None])
|
||||
attn_bias = LowerTriangularMaskWithTensorBias(bias)
|
||||
input_metadata.attn_bias = attn_bias
|
||||
|
||||
def multi_query_kv_attention(
|
||||
self,
|
||||
@ -391,56 +432,20 @@ class PagedAttentionWithALiBi(PagedAttention):
|
||||
value = torch.repeat_interleave(value,
|
||||
self.num_queries_per_kv,
|
||||
dim=1)
|
||||
batch_size = input_metadata.num_prompts
|
||||
seq_len = input_metadata.max_prompt_len
|
||||
|
||||
# FIXME(woosuk): Because xformers does not support dynamic sequence
|
||||
# lengths with custom attention bias, we process each prompt one by
|
||||
# one. This is inefficient, especially when we have many short prompts.
|
||||
start = 0
|
||||
for i, prompt_len in enumerate(input_metadata.prompt_lens):
|
||||
end = start + prompt_len
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query[None, start:end],
|
||||
key[None, start:end],
|
||||
value[None, start:end],
|
||||
attn_bias=input_metadata.attn_bias[i],
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
)
|
||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||
output[start:end].copy_(out.squeeze(0))
|
||||
start += prompt_len
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query.view(batch_size, seq_len, self.num_heads, self.head_size),
|
||||
key.view(batch_size, seq_len, self.num_heads, self.head_size),
|
||||
value.view(batch_size, seq_len, self.num_heads, self.head_size),
|
||||
attn_bias=input_metadata.attn_bias,
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
)
|
||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||
output.copy_(out.view(-1, self.num_heads, self.head_size))
|
||||
return output
|
||||
|
||||
def single_query_cached_kv_attention(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> None:
|
||||
"""PagedAttention with ALiBi bias for the generation tokens.
|
||||
|
||||
Args:
|
||||
output: shape = [num_generation_tokens, num_heads, head_size]
|
||||
query: shape = [num_generation_tokens, num_heads, head_size]
|
||||
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
|
||||
block_size, x]
|
||||
value_cache: shape = [num_blocks, num_kv_heads, head_size,
|
||||
block_size]
|
||||
input_metadata: metadata for paged attention.
|
||||
"""
|
||||
block_size = value_cache.shape[3]
|
||||
attention_ops.single_query_cached_kv_attention(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
self.head_mapping,
|
||||
self.scale,
|
||||
input_metadata.block_tables,
|
||||
input_metadata.context_lens,
|
||||
block_size,
|
||||
input_metadata.max_context_len,
|
||||
self.alibi_slopes,
|
||||
)
|
||||
def get_alibi_slopes(self) -> Optional[torch.Tensor]:
|
||||
return self.alibi_slopes
|
||||
|
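To make the ALiBi bias construction above concrete: `bias[None, :] - bias[:, None]` yields a matrix of relative positions `j - i`, which each head then scales by its own slope before the causal mask removes the upper triangle. A small self-contained sketch (prompt length and slope values are illustrative only):

```python
import torch

max_prompt_len = 4
bias = torch.arange(max_prompt_len, dtype=torch.float32)
bias = bias[None, :] - bias[:, None]   # entry [i, j] = j - i
# tensor([[ 0.,  1.,  2.,  3.],
#         [-1.,  0.,  1.,  2.],
#         [-2., -1.,  0.,  1.],
#         [-3., -2., -1.,  0.]])

alibi_slopes = torch.tensor([0.5, 0.25])            # 2 heads, illustrative slopes
per_head_bias = bias * alibi_slopes[:, None, None]  # shape [2, 4, 4]
# Only the lower triangle (j <= i, non-positive penalties) survives the
# LowerTriangularMaskWithTensorBias applied in set_attn_bias above.
```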
@ -1,4 +1,6 @@
|
||||
"""Custom normalization layers."""
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -21,7 +23,19 @@ class RMSNorm(nn.Module):
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if residual is not None:
|
||||
layernorm_ops.fused_add_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
out = torch.empty_like(x)
|
||||
layernorm_ops.rms_norm(
|
||||
out,
|
||||
|
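The new `residual` argument lets the kernel fold the residual addition into the normalization. A plain-PyTorch sketch of the semantics this fused op is assumed to implement (the real `layernorm_ops.fused_add_rms_norm` works in place and in a single pass over the hidden dimension):

```python
import torch


def fused_add_rms_norm_ref(x: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, eps: float):
    """Reference semantics assumed for the fused kernel: the residual stream is
    updated to x + residual, and x is replaced by its RMS-normalized value."""
    z = x + residual
    variance = z.pow(2).mean(dim=-1, keepdim=True)
    normed = z * torch.rsqrt(variance + eps) * weight
    return normed, z   # matches the (x, residual) tuple returned by forward()
```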
541
vllm/model_executor/layers/linear.py
Normal file
@ -0,0 +1,541 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
|
||||
from vllm.model_executor.parallel_utils.utils import (
|
||||
divide, split_tensor_along_last_dim)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LinearMethodBase(ABC):
|
||||
"""Base class for different (maybe quantized) linear methods."""
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
|
||||
"""Create weights for a linear layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, torch.Tensor],
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""Apply the weights to the input tensor."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class UnquantizedLinearMethod(LinearMethodBase):
|
||||
"""Linear method without quantization.
|
||||
|
||||
Args:
|
||||
separate_bias_add: If true, add bias separately after matrix
|
||||
multiplication.
|
||||
"""
|
||||
|
||||
def __init__(self, separate_bias_add: bool = False):
|
||||
self.separate_bias_add = separate_bias_add
|
||||
|
||||
def create_weights(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
|
||||
weight = Parameter(torch.empty(output_size,
|
||||
input_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
return {"weight": weight}
|
||||
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, torch.Tensor],
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
weight = weights["weight"]
|
||||
if self.separate_bias_add:
|
||||
if bias:
|
||||
return F.linear(x, weight) + bias
|
||||
return F.linear(x, weight)
|
||||
return F.linear(x, weight, bias)
|
||||
|
||||
|
||||
class ReplicatedLinear(torch.nn.Module):
|
||||
"""Replicated linear layer.
|
||||
|
||||
Args:
|
||||
input_size: input dimension of the linear layer.
|
||||
output_size: output dimension of the linear layer.
|
||||
bias: If true, add bias.
|
||||
skip_bias_add: If true, skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
linear_method: (Maybe quantized) linear method.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.skip_bias_add = skip_bias_add
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
self.linear_method = linear_method
|
||||
self.linear_weights = self.linear_method.create_weights(
|
||||
self.input_size, self.output_size, self.params_dtype)
|
||||
for name, weight in self.linear_weights.items():
|
||||
self.register_parameter(name, weight)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=self.params_dtype))
|
||||
set_weight_attrs(self.bias, {"output_dim": 0})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
output = self.linear_method.apply_weights(self.linear_weights, x, bias)
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class ColumnParallelLinear(torch.nn.Module):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
its second dimension as A = [A_1, ..., A_p].
|
||||
|
||||
Args:
|
||||
input_size: first dimension of matrix A.
|
||||
output_size: second dimension of matrix A.
|
||||
bias: If true, add bias.
|
||||
gather_output: If true, call all-gather on output and make Y available
|
||||
to all GPUs, otherwise, every GPU will have its output
|
||||
which is Y_i = XA_i
|
||||
skip_bias_add: This was added to enable performance optimizations where
|
||||
bias can be fused with other element-wise operations. We
|
||||
skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
linear_method: (Maybe quantized) linear method.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.gather_output = gather_output
|
||||
# Divide the weight matrix along the last dimension.
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.output_size_per_partition = divide(output_size, tp_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
self.linear_method = linear_method
|
||||
self.linear_weights = self.linear_method.create_weights(
|
||||
self.input_size, self.output_size_per_partition, self.params_dtype)
|
||||
for name, weight in self.linear_weights.items():
|
||||
self.register_parameter(name, weight)
|
||||
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
param_data = param.data
|
||||
if output_dim is not None:
|
||||
shard_size = param_data.shape[output_dim]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, input_):
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.linear_method.apply_weights(
|
||||
self.linear_weights, input_, bias)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
"""Packed linear layers with column parallelism.
|
||||
|
||||
Similar to ColumnParallelLinear, but the weight matrix is concatenated
|
||||
along the output dimension. When the weight matrix is loaded, the
|
||||
different partitions are sharded separately.
|
||||
|
||||
Args:
|
||||
input_size: input dimension of the linear layer.
|
||||
output_sizes: list of output dimensions of the linear layer.
|
||||
bias: If true, add bias.
|
||||
gather_output: If true, call all-gather on output and make the output
|
||||
available to all GPUs, otherwise, every GPU will have
|
||||
its own output.
|
||||
skip_bias_add: This was added to enable performance optimizations where
|
||||
bias can be fused with other element-wise operations. We
|
||||
skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
linear_method: (Maybe quantized) linear method.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_sizes: List[int],
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
self.output_sizes = output_sizes
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert all(output_size % tp_size == 0 for output_size in output_sizes)
|
||||
super().__init__(input_size, sum(output_sizes), bias, gather_output,
|
||||
skip_bias_add, params_dtype, linear_method)
|
||||
|
||||
def weight_loader(self,
|
||||
param: Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[int] = None):
|
||||
param_data = param.data
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already packed.
|
||||
if output_dim is None:
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
current_shard_offset = 0
|
||||
shard_offsets = []
|
||||
for i, output_size in enumerate(self.output_sizes):
|
||||
shard_offsets.append((i, current_shard_offset, output_size))
|
||||
current_shard_offset += output_size
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
output_dim, shard_offset, shard_size)
|
||||
self.weight_loader(param, loaded_weight_shard, shard_id)
|
||||
return
|
||||
|
||||
assert loaded_shard_id < len(self.output_sizes)
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if output_dim is not None:
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
|
||||
shard_size = self.output_sizes[loaded_shard_id] // tp_size
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
param_data = param_data.narrow(output_dim, shard_offset,
|
||||
shard_size)
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
else:
|
||||
logger.warning(
|
||||
"Loading a weight without `output_dim` attribute in "
|
||||
"MergedColumnParallelLinear, assume the weight is "
|
||||
"the same for all partitions.")
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class QKVParallelLinear(ColumnParallelLinear):
|
||||
"""Linear layers for the attention's QKV transformation.
|
||||
|
||||
Linear layers for the linear transformation of the query, key, and value
|
||||
vectors in the attention layer. The weight matrix is concatenated along
|
||||
the output dimension. The layer is parallelized along the head dimension.
|
||||
When the number of key/value heads is smaller than the number of query
|
||||
heads (e.g., multi-query/grouped-query attention), the key/value head may
|
||||
be replicated while the query heads are partitioned.
|
||||
|
||||
Args:
|
||||
hidden_size: input hidden state size of the transformer.
|
||||
head_size: size of each attention head.
|
||||
total_num_heads: total number of attention query heads.
|
||||
total_num_kv_heads: total number of attention key/value heads. If
|
||||
None, assume total_num_kv_heads = total_num_heads.
|
||||
bias: If true, add bias.
|
||||
skip_bias_add: This was added to enable performance optimizations where
|
||||
bias can be fused with other element-wise operations. We
|
||||
skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
linear_method: (Maybe quantized) linear method.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
head_size: int,
|
||||
total_num_heads: int,
|
||||
total_num_kv_heads: Optional[int] = None,
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = head_size
|
||||
self.total_num_heads = total_num_heads
|
||||
if total_num_kv_heads is None:
|
||||
total_num_kv_heads = total_num_heads
|
||||
self.total_num_kv_heads = total_num_kv_heads
|
||||
# Divide the weight matrix along the last dimension.
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads = divide(self.total_num_heads, tp_size)
|
||||
if tp_size >= self.total_num_kv_heads:
|
||||
self.num_kv_heads = 1
|
||||
self.num_kv_head_replicas = divide(tp_size,
|
||||
self.total_num_kv_heads)
|
||||
else:
|
||||
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
|
||||
self.num_kv_head_replicas = 1
|
||||
input_size = self.hidden_size
|
||||
output_size = (self.num_heads +
|
||||
2 * self.num_kv_heads) * tp_size * self.head_size
|
||||
super().__init__(input_size, output_size, bias, False, skip_bias_add,
|
||||
params_dtype, linear_method)
|
||||
|
||||
def weight_loader(self,
|
||||
param: Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[str] = None):
|
||||
param_data = param.data
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already packed.
|
||||
if output_dim is None:
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
shard_offsets = [
|
||||
# (shard_id, shard_offset, shard_size)
|
||||
("q", 0, self.total_num_heads * self.head_size),
|
||||
("k", self.total_num_heads * self.head_size,
|
||||
self.total_num_kv_heads * self.head_size),
|
||||
("v", (self.total_num_heads + self.total_num_kv_heads) *
|
||||
self.head_size, self.total_num_kv_heads * self.head_size),
|
||||
]
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
output_dim, shard_offset, shard_size)
|
||||
self.weight_loader(param, loaded_weight_shard, shard_id)
|
||||
return
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
assert loaded_shard_id in ["q", "k", "v"]
|
||||
if output_dim is not None:
|
||||
if loaded_shard_id == "q":
|
||||
shard_offset = 0
|
||||
shard_size = self.num_heads * self.head_size
|
||||
elif loaded_shard_id == "k":
|
||||
shard_offset = self.num_heads * self.head_size
|
||||
shard_size = self.num_kv_heads * self.head_size
|
||||
elif loaded_shard_id == "v":
|
||||
shard_offset = (self.num_heads +
|
||||
self.num_kv_heads) * self.head_size
|
||||
shard_size = self.num_kv_heads * self.head_size
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
param_data = param_data.narrow(output_dim, shard_offset,
|
||||
shard_size)
|
||||
shard_id = tp_rank // self.num_kv_head_replicas
|
||||
start_idx = shard_id * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
else:
|
||||
logger.warning(
|
||||
"Loading a weight without `output_dim` attribute in "
|
||||
"QKVParallelLinear, assume the weight is the same "
|
||||
"for all partitions.")
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class RowParallelLinear(torch.nn.Module):
|
||||
"""Linear layer with row parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
its first dimension and X along its second dimension as:
|
||||
- -
|
||||
| A_1 |
|
||||
| . |
|
||||
A = | . | X = [X_1, ..., X_p]
|
||||
| . |
|
||||
| A_p |
|
||||
- -
|
||||
Arguments:
|
||||
input_size: first dimension of matrix A.
|
||||
output_size: second dimension of matrix A.
|
||||
bias: If true, add bias. Note that bias is not parallelized.
|
||||
input_is_parallel: If true, we assume that the input is already
|
||||
split across the GPUs and we do not split
|
||||
again.
|
||||
skip_bias_add: This was added to enable performance optimization where
|
||||
bias can be fused with other element-wise operations.
|
||||
We skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
linear_method: (Maybe quantized) linear method.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
input_is_parallel: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = True,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
self.linear_method = linear_method
|
||||
self.linear_weights = self.linear_method.create_weights(
|
||||
self.input_size_per_partition, self.output_size, self.params_dtype)
|
||||
for name, weight in self.linear_weights.items():
|
||||
self.register_parameter(name, weight)
|
||||
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
|
||||
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
"results can lead to incorrect results")
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
param_data = param.data
|
||||
if input_dim is not None:
|
||||
shard_size = param_data.shape[input_dim]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
|
||||
shard_size)
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, input_):
|
||||
# Set up backprop all-reduce.
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.linear_method.apply_weights(
|
||||
self.linear_weights, input_parallel)
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output_ = output_parallel
|
||||
|
||||
if not self.skip_bias_add:
|
||||
output = output_ + self.bias if self.bias is not None else output_
|
||||
output_bias = None
|
||||
else:
|
||||
output = output_
|
||||
output_bias = self.bias
|
||||
return output, output_bias
|
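Taken together, `ColumnParallelLinear` (sharded outputs, no gather by default) and `RowParallelLinear` (sharded inputs, all-reduce on the way out) compose into the usual Megatron-style pattern where no communication is needed between the two projections. A hedged sketch of such a composition (the MLP class and its sizes are illustrative, not part of this diff; it assumes the tensor-parallel process group is already initialized, and both layers return an `(output, output_bias)` tuple):

```python
import torch.nn as nn

from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)


class ParallelMLP(nn.Module):
    """Illustrative composition of the column/row-parallel layers above."""

    def __init__(self, hidden_size: int = 4096, intermediate_size: int = 11008):
        super().__init__()
        # Output stays sharded across tensor-parallel ranks (gather_output=False).
        self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size, bias=False)
        self.act = nn.SiLU()
        # Input is already sharded; the all-reduce happens inside forward().
        self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False,
                                           input_is_parallel=True)

    def forward(self, x):
        x, _ = self.up_proj(x)    # per rank: [..., intermediate_size // tp_size]
        x = self.act(x)
        x, _ = self.down_proj(x)  # all-reduced back to [..., hidden_size]
        return x
```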
22
vllm/model_executor/layers/quantization/__init__.py
Normal file
@ -0,0 +1,22 @@
|
||||
from typing import Type
|
||||
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
|
||||
_QUANTIZATION_CONFIG_REGISTRY = {
|
||||
"awq": AWQConfig,
|
||||
"squeezellm": SqueezeLLMConfig,
|
||||
}
|
||||
|
||||
|
||||
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
|
||||
raise ValueError(f"Invalid quantization method: {quantization}")
|
||||
return _QUANTIZATION_CONFIG_REGISTRY[quantization]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"QuantizationConfig",
|
||||
"get_quantization_config",
|
||||
]
|
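A quick illustration of the registry above; the config dict is a made-up example of the key spellings `AWQConfig.from_config` accepts (see `get_from_keys` usage in awq.py below):

```python
from vllm.model_executor.layers.quantization import get_quantization_config

config_cls = get_quantization_config("awq")   # -> AWQConfig
quant_config = config_cls.from_config({
    "w_bit": 4,                               # also accepted as "bits"
    "q_group_size": 128,                      # also accepted as "group_size"
    "zero_point": True,
})
print(quant_config)  # AWQConfig(weight_bits=4, group_size=128, zero_point=True)

# get_quantization_config("gptq")  # would raise ValueError: Invalid quantization method
```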
158
vllm/model_executor/layers/quantization/awq.py
Normal file
@ -0,0 +1,158 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import quantization_ops
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
|
||||
|
||||
class AWQConfig(QuantizationConfig):
|
||||
"""Config class for AWQ.
|
||||
|
||||
Reference: https://arxiv.org/abs/2306.00978
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
) -> None:
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
|
||||
if self.weight_bits != 4:
|
||||
raise ValueError(
|
||||
"Currently, only 4-bit weight quantization is supported for "
|
||||
f"AWQ, but got {self.weight_bits} bits.")
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"AWQConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"zero_point={self.zero_point})")
|
||||
|
||||
def get_name(self) -> str:
|
||||
return "awq"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
def get_min_capability(self) -> int:
|
||||
# The AWQ kernel only supports Turing or newer GPUs.
|
||||
return 75
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
return [
|
||||
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
|
||||
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
return cls(weight_bits, group_size, zero_point)
|
||||
|
||||
def get_linear_method(self) -> "AWQLinearMethod":
|
||||
return AWQLinearMethod(self)
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
|
||||
class AWQLinearMethod(LinearMethodBase):
|
||||
"""Linear method for AWQ.
|
||||
|
||||
Args:
|
||||
quant_config: The AWQ quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AWQConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
|
||||
if input_size % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
if output_size % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
qweight = Parameter(
|
||||
torch.empty(
|
||||
input_size,
|
||||
output_size // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qweight, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
"packed_dim": 1,
|
||||
"pack_factor": self.quant_config.pack_factor,
|
||||
})
|
||||
qzeros = Parameter(
|
||||
torch.empty(
|
||||
input_size // self.quant_config.group_size,
|
||||
output_size // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qzeros, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
"packed_dim": 1,
|
||||
"pack_factor": self.quant_config.pack_factor,
|
||||
})
|
||||
scales = Parameter(
|
||||
torch.empty(
|
||||
input_size // self.quant_config.group_size,
|
||||
output_size,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(scales, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
})
|
||||
return {
|
||||
"qweight": qweight,
|
||||
"qzeros": qzeros,
|
||||
"scales": scales,
|
||||
}
|
||||
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, torch.Tensor],
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qweight = weights["qweight"]
|
||||
qzeros = weights["qzeros"]
|
||||
scales = weights["scales"]
|
||||
pack_factor = self.quant_config.pack_factor
|
||||
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out = quantization_ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
|
||||
pack_factor)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.reshape(out_shape)
|
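Since `pack_factor = 32 // weight_bits = 8` for 4-bit AWQ, eight quantized values share one int32 along the output dimension, and scales/zeros are stored once per group of 128 input channels. For an illustrative 4096x11008 layer, the shapes created above work out as follows:

```python
input_size, output_size = 4096, 11008     # illustrative layer dimensions
weight_bits, group_size = 4, 128
pack_factor = 32 // weight_bits           # 8 int4 values per int32

qweight_shape = (input_size, output_size // pack_factor)               # (4096, 1376) int32
qzeros_shape = (input_size // group_size, output_size // pack_factor)  # (32, 1376) int32
scales_shape = (input_size // group_size, output_size)                 # (32, 11008) fp16

# apply_weights() recovers the logical output width from the packed shape:
assert qweight_shape[-1] * pack_factor == output_size
```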
64
vllm/model_executor/layers/quantization/base_config.py
Normal file
@ -0,0 +1,64 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
|
||||
|
||||
class QuantizationConfig(ABC):
|
||||
"""Base class for quantization configs."""
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> str:
|
||||
"""Name of the quantization method."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
"""List of supported activation dtypes."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_min_capability(self) -> int:
|
||||
"""Minimum GPU capability to support the quantization method.
|
||||
|
||||
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
|
||||
This requirement is due to the custom CUDA kernels used by the
|
||||
quantization method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
"""List of filenames to search for in the model directory."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
|
||||
"""Create a config class from the model's quantization config."""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
|
||||
"""Get a value from the model's quantization config."""
|
||||
for key in keys:
|
||||
if key in config:
|
||||
return config[key]
|
||||
raise ValueError(f"Cannot find any of {keys} in the model's "
|
||||
"quantization config.")
|
||||
|
||||
@abstractmethod
|
||||
def get_linear_method(self) -> LinearMethodBase:
|
||||
"""Get the linear method to use for the quantized linear layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
"""Returns the activation function names that should be post-scaled.
|
||||
|
||||
For now, this is only used by AWQ.
|
||||
"""
|
||||
raise NotImplementedError
|
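To show what the abstract interface above demands of a new backend, here is a hypothetical no-op config that stores weights unquantized; it is not part of this diff and would still need an entry in `_QUANTIZATION_CONFIG_REGISTRY` to be selectable:

```python
from typing import Any, Dict, List

import torch

from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


class NoopQuantConfig(QuantizationConfig):
    """Hypothetical config that keeps weights unquantized; interface demo only."""

    def get_name(self) -> str:
        return "noop"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half, torch.float]

    def get_min_capability(self) -> int:
        return 70  # no custom kernels, so any Volta-or-newer GPU

    @staticmethod
    def get_config_filenames() -> List[str]:
        return ["quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "NoopQuantConfig":
        return cls()

    def get_linear_method(self) -> LinearMethodBase:
        return UnquantizedLinearMethod()

    def get_scaled_act_names(self) -> List[str]:
        return []
```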
124
vllm/model_executor/layers/quantization/squeezellm.py
Normal file
@ -0,0 +1,124 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import quantization_ops
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
|
||||
|
||||
class SqueezeLLMConfig(QuantizationConfig):
|
||||
"""Config class for SqueezeLLM.
|
||||
|
||||
Reference: https://arxiv.org/pdf/2306.07629
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
) -> None:
|
||||
self.weight_bits = weight_bits
|
||||
|
||||
if self.weight_bits != 4:
|
||||
raise ValueError(
|
||||
"Currently, only 4-bit weight quantization is supported for "
|
||||
f"SqueezeLLM, but got {self.weight_bits} bits.")
|
||||
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"
|
||||
|
||||
def get_name(self) -> str:
|
||||
return "squeezellm"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
def get_min_capability(self) -> int:
|
||||
return 70
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
return ["quant_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["wbits"])
|
||||
return cls(weight_bits)
|
||||
|
||||
def get_linear_method(self) -> "SqueezeLLMLinearMethod":
|
||||
return SqueezeLLMLinearMethod(self)
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
class SqueezeLLMLinearMethod(LinearMethodBase):
|
||||
"""Linear method for SqueezeLLM.
|
||||
|
||||
Args:
|
||||
quant_config: The SqueezeLLM quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: SqueezeLLMConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
|
||||
if input_size % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
qweight = Parameter(
|
||||
torch.empty(
|
||||
input_size // self.quant_config.pack_factor,
|
||||
output_size,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qweight, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
"packed_dim": 0,
|
||||
"pack_factor": self.quant_config.pack_factor,
|
||||
})
|
||||
lookup_table = Parameter(
|
||||
torch.empty(
|
||||
output_size,
|
||||
self.quant_config.weight_bits**2,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(lookup_table, {
|
||||
"output_dim": 0,
|
||||
})
|
||||
return {
|
||||
"qweight": qweight,
|
||||
"lookup_table": lookup_table,
|
||||
}
|
||||
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, torch.Tensor],
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qweight = weights["qweight"]
|
||||
lookup_table = weights["lookup_table"]
|
||||
out_shape = x.shape[:-1] + (qweight.shape[-1], )
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
# NOTE: The output tensor should be zero-initialized.
|
||||
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
|
||||
quantization_ops.squeezellm_gemm(reshaped_x, qweight, out,
|
||||
lookup_table)
|
||||
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.reshape(out_shape)
|
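SqueezeLLM packs the 4-bit codes along the input dimension (`packed_dim=0`) and keeps a per-output-channel table of centroids; for the 4-bit case supported here that table has 16 entries per channel (`weight_bits**2 == 2**weight_bits` only when `weight_bits == 4`). A reference dequantization sketch, assuming least-significant-nibble-first packing inside each int32 (the CUDA kernel never materializes this dense matrix):

```python
import torch


def squeezellm_dequant_ref(qweight: torch.Tensor,      # [in_size // 8, out_size] int32
                           lookup_table: torch.Tensor  # [out_size, 16] fp16
                           ) -> torch.Tensor:
    """Sketch of the assumed dequantization: w[i, j] = lookup_table[j, code(i, j)]."""
    pack_factor, bits = 8, 4
    packed_rows, out_features = qweight.shape
    codes = torch.empty(packed_rows * pack_factor, out_features, dtype=torch.long)
    for k in range(pack_factor):
        # Assumed layout: nibble k of packed row r holds original input row r*8 + k.
        codes[k::pack_factor] = (qweight >> (k * bits)) & 0xF
    return lookup_table[torch.arange(out_features), codes]  # [in_size, out_size]
```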
316
vllm/model_executor/layers/rotary_embedding.py
Normal file
@ -0,0 +1,316 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Rotary Positional Embeddings."""
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm import pos_encoding_ops
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
"""Original rotary positional embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
self.rotary_dim = rotary_dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.is_neox_style = is_neox_style
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
cache = cache.to(torch.get_default_dtype())
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
"""Compute the inverse frequency."""
|
||||
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
|
||||
# However, we use `torch.arange(..., dtype=torch.float)` instead to
|
||||
# avoid numerical issues with large base values (e.g., 10000000).
|
||||
# This may cause a slight numerical difference between the HF
|
||||
# implementation and ours.
|
||||
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
||||
# use CPU to compute the cache and then move it to GPU. However, we
|
||||
# create the cache on GPU for faster initialization. This may cause
|
||||
# a slight numerical difference between the HF implementation and ours.
|
||||
inv_freq = 1.0 / (base**(torch.arange(
|
||||
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
|
||||
self.rotary_dim))
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
"""Compute the cos and sin cache."""
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
t = torch.arange(self.max_position_embeddings,
|
||||
dtype=torch.float,
|
||||
device="cuda")
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# pos_encoding_ops.rotary_embedding() is an in-place operation that
|
||||
# updates the query and key tensors.
|
||||
pos_encoding_ops.rotary_embedding(positions, query, key,
|
||||
self.head_size, self.cos_sin_cache,
|
||||
self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
|
||||
class LinearScalingRotaryEmbedding(RotaryEmbedding):
|
||||
"""RotaryEmbedding extended with linear scaling.
|
||||
|
||||
Credits to the Reddit user /u/kaiokendev
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style)
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
# NOTE(woosuk): self.max_position_embeddings is the original
|
||||
# maximum length before applying the rope scaling.
|
||||
# Thus, the maximum length after applying the rope scaling is
|
||||
# self.max_position_embeddings * self.scaling_factor.
|
||||
max_len = self.max_position_embeddings * self.scaling_factor
|
||||
t = torch.arange(max_len, dtype=torch.float, device="cuda")
|
||||
t = t / self.scaling_factor
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
|
||||
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
|
||||
"""RotaryEmbedding extended with Dynamic NTK scaling.
|
||||
|
||||
Credits to the Reddit users /u/bloc97 and /u/emozilla
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style)
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
# NOTE(woosuk): self.max_position_embeddings is the original
|
||||
# maximum length before applying the rope scaling.
|
||||
# Thus, the maximum length after applying the rope scaling is
|
||||
# self.max_position_embeddings * self.scaling_factor.
|
||||
max_len = self.max_position_embeddings * self.scaling_factor
|
||||
base = self.base * (
|
||||
(self.scaling_factor * max_len / self.max_position_embeddings) -
|
||||
(self.scaling_factor - 1))**(self.rotary_dim /
|
||||
(self.rotary_dim - 2))
|
||||
inv_freq = self._compute_inv_freq(base)
|
||||
t = torch.arange(max_len, dtype=torch.float, device="cuda")
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
|
||||
# Inverse dim formula to find dim based on number of rotations
|
||||
def _yarn_find_correction_dim(num_rotations: int,
|
||||
dim: int,
|
||||
base: float = 10000,
|
||||
max_position_embeddings: int = 2048) -> float:
|
||||
return (dim * math.log(max_position_embeddings /
|
||||
(num_rotations * 2 * math.pi))) / (2 *
|
||||
math.log(base))
|
||||
|
||||
|
||||
# Find dim range bounds based on rotations
|
||||
def _yarn_find_correction_range(low_rot: int,
|
||||
high_rot: int,
|
||||
dim: int,
|
||||
base: float = 10000,
|
||||
max_position_embeddings: int = 2048) -> int:
|
||||
low = math.floor(
|
||||
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
||||
high = math.ceil(
|
||||
_yarn_find_correction_dim(high_rot, dim, base,
|
||||
max_position_embeddings))
|
||||
return max(low, 0), min(high, dim - 1) # Clamp values just in case
|
||||
|
||||
|
||||
def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
if low == high:
|
||||
high += 0.001 # Prevent singularity
|
||||
|
||||
linear_func = (torch.arange(dim, dtype=dtype, device=device) -
|
||||
low) / (high - low)
|
||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
|
||||
def _yarn_get_mscale(scale: float = 1) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * math.log(scale) + 1.0
|
||||
|
||||


class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: float = 32,
        beta_slow: float = 1,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation
        self.mscale = float(
            _yarn_get_mscale(self.scaling_factor) * attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        pos_freqs = self.base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
                                self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                                self.rotary_dim, self.base,
                                                self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2, dtype=torch.float,
            device="cuda")) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         device="cuda",
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = (freqs.cos() * self.mscale)
        sin = (freqs.sin() * self.mscale)
        cache = torch.cat((cos, sin), dim=-1)
        return cache


def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool,
    rope_scaling: Optional[Dict[str, Any]],
) -> RotaryEmbedding:
    if rope_scaling is None:
        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                     is_neox_style)
    else:
        scaling_type = rope_scaling["type"]
        scaling_factor = rope_scaling["factor"]
        if scaling_type == "linear":
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_factor)
        elif scaling_type == "dynamic":
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size, rotary_dim, max_position, base, is_neox_style,
                scaling_factor)
        elif scaling_type == "yarn":
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow")
            }
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    original_max_position,
                                                    base, is_neox_style,
                                                    scaling_factor,
                                                    **extra_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    return rotary_emb
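For reference, the factory above is meant to be driven by a HuggingFace-style rope_scaling dict. The snippet below is a hypothetical usage sketch; the sizes are made up, and a CUDA device is needed because the caches are built on "cuda":

# Hypothetical usage; sizes and scaling values are illustrative only.
rope_scaling = {"type": "yarn", "factor": 4.0,
                "original_max_position_embeddings": 4096,
                "beta_fast": 32, "beta_slow": 1}
rotary_emb = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=4096 * 4,   # must equal original_max_position * factor
    base=10000,
    is_neox_style=True,
    rope_scaling=rope_scaling,
)
# get_rope returns a YaRNScalingRotaryEmbedding whose cos/sin cache covers
# the scaled 16384-token context.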
@ -1,15 +1,15 @@
"""A layer that samples the next tokens from the model's outputs."""
from typing import Dict, List, Tuple, Optional
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.parallel_utils.tensor_parallel import (
    gather_from_tensor_model_parallel_region)
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceOutputs
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_gather)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                           SequenceData, SequenceGroupOutputs, SequenceOutputs)

_SAMPLING_EPS = 1e-5

@ -44,22 +44,21 @@ class Sampler(nn.Module):
|
||||
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
|
||||
|
||||
# Get the logits for the next tokens.
|
||||
logits = torch.matmul(hidden_states, embedding.t())
|
||||
if embedding_bias is not None:
|
||||
logits += embedding_bias
|
||||
logits = gather_from_tensor_model_parallel_region(logits)
|
||||
# Remove paddings in vocab (if any).
|
||||
logits = logits[:, :self.vocab_size]
|
||||
logits = _get_logits(hidden_states, embedding, embedding_bias,
|
||||
self.vocab_size)
|
||||
|
||||
# Apply logits processors (if any).
|
||||
logits = _apply_logits_processors(logits, input_metadata)
|
||||
# Apply presence and frequency penalties.
|
||||
output_tokens = _get_output_tokens(input_metadata)
|
||||
assert len(output_tokens) == logits.shape[0]
|
||||
presence_penalties, frequency_penalties = _get_penalties(
|
||||
input_metadata)
|
||||
presence_penalties, frequency_penalties, repetition_penalties = (
|
||||
_get_penalties(input_metadata))
|
||||
assert len(presence_penalties) == logits.shape[0]
|
||||
assert len(frequency_penalties) == logits.shape[0]
|
||||
assert len(repetition_penalties) == logits.shape[0]
|
||||
logits = _apply_penalties(logits, output_tokens, presence_penalties,
|
||||
frequency_penalties, self.vocab_size)
|
||||
frequency_penalties, repetition_penalties)
|
||||
|
||||
# Apply temperature scaling.
|
||||
temperatures = _get_temperatures(input_metadata)
|
||||
@ -72,121 +71,179 @@ class Sampler(nn.Module):
|
||||
logits.div_(t.unsqueeze(dim=1))
|
||||
|
||||
# Apply top-p and top-k truncation.
|
||||
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
|
||||
top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
|
||||
input_metadata, self.vocab_size)
|
||||
assert len(top_ps) == len(top_ks) == logits.shape[0]
|
||||
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
|
||||
do_top_k = any(k != self.vocab_size for k in top_ks)
|
||||
if do_top_p or do_top_k:
|
||||
logits = _apply_top_p_top_k(logits, top_ps, top_ks)
|
||||
|
||||
do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
|
||||
if do_min_p:
|
||||
logits = _apply_min_p(logits, min_ps)
|
||||
|
||||
# We use float32 for probabilities and log probabilities.
|
||||
# Compute the probabilities.
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||
# Compute the log probabilities (before applying top-p and top-k).
|
||||
logprobs = torch.log(probs)
|
||||
# Compute the log probabilities.
|
||||
# Use log_softmax to ensure numerical stability.
|
||||
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
||||
|
||||
# Sample the next tokens.
|
||||
return _sample(probs, logprobs, input_metadata)
|
||||
sample_results = _sample(probs, logprobs, input_metadata)
|
||||
# Get the logprobs query results.
|
||||
prompt_logprobs, sample_logprobs = _get_logprobs(
|
||||
logprobs, input_metadata, sample_results)
|
||||
return _build_sampler_output(sample_results, input_metadata,
|
||||
prompt_logprobs, sample_logprobs)
|
||||
|
||||
|
||||
def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
|
||||
embedding_bias: Optional[torch.Tensor],
|
||||
vocab_size: int) -> torch.Tensor:
|
||||
# Get the logits for the next tokens.
|
||||
logits = torch.matmul(hidden_states, embedding.t())
|
||||
if embedding_bias is not None:
|
||||
logits += embedding_bias
|
||||
logits = tensor_model_parallel_all_gather(logits)
|
||||
# Remove paddings in vocab (if any).
|
||||
logits = logits[:, :vocab_size]
|
||||
return logits
|
||||
|
||||
|
||||
def _prune_hidden_states(
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
start_idx = 0
|
||||
last_token_indicies: List[int] = []
|
||||
for prompt_len in input_metadata.prompt_lens:
|
||||
last_token_indicies.append(start_idx + prompt_len - 1)
|
||||
start_idx += prompt_len
|
||||
last_token_indicies.extend(
|
||||
range(start_idx, start_idx + input_metadata.num_generation_tokens))
|
||||
return hidden_states.index_select(
|
||||
0, torch.tensor(last_token_indicies, device=hidden_states.device))
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
return hidden_states.index_select(0, input_metadata.selected_token_indices)
|
||||
|
||||
|
||||
def _get_penalties(
|
||||
input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
|
||||
input_metadata: InputMetadata
|
||||
) -> Tuple[List[float], List[float], List[float]]:
|
||||
# Collect the presence and frequency penalties.
|
||||
presence_penalties: List[float] = []
|
||||
frequency_penalties: List[float] = []
|
||||
repetition_penalties: List[float] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_ids, sampling_params = seq_group
|
||||
p = sampling_params.presence_penalty
|
||||
f = sampling_params.frequency_penalty
|
||||
if i < input_metadata.num_prompts:
|
||||
# A prompt input.
|
||||
presence_penalties.append(p)
|
||||
frequency_penalties.append(f)
|
||||
else:
|
||||
# A generation token.
|
||||
presence_penalties += [p] * len(seq_ids)
|
||||
frequency_penalties += [f] * len(seq_ids)
|
||||
return presence_penalties, frequency_penalties
|
||||
r = sampling_params.repetition_penalty
|
||||
if (i < input_metadata.num_prompts
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
# NOTE: We do not apply presence and frequency penalties for the
|
||||
# prompt token positions where we don't sample new tokens.
|
||||
prompt_len = input_metadata.prompt_lens[i]
|
||||
presence_penalties += [0] * (prompt_len - 1)
|
||||
frequency_penalties += [0] * (prompt_len - 1)
|
||||
repetition_penalties += [1] * (prompt_len - 1)
|
||||
presence_penalties += [p] * len(seq_ids)
|
||||
frequency_penalties += [f] * len(seq_ids)
|
||||
repetition_penalties += [r] * len(seq_ids)
|
||||
return presence_penalties, frequency_penalties, repetition_penalties
|
||||
|
||||
|
||||
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
|
||||
output_tokens: List[List[int]] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_ids, _ = seq_group
|
||||
if i < input_metadata.num_prompts:
|
||||
# A prompt input.
|
||||
# NOTE: While the prompt input usually has no output tokens,
|
||||
# it may have output tokens in the case of recomputation.
|
||||
seq_id = seq_ids[0]
|
||||
seq_ids, sampling_params = seq_group
|
||||
if (i < input_metadata.num_prompts
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
# NOTE: prompt token positions do not need output tokens to
|
||||
# compute penalties.
|
||||
prompt_len = input_metadata.prompt_lens[i]
|
||||
output_tokens.extend([] for _ in range(prompt_len - 1))
|
||||
for seq_id in seq_ids:
|
||||
seq_data = input_metadata.seq_data[seq_id]
|
||||
output_tokens.append(seq_data.output_token_ids)
|
||||
else:
|
||||
# A generation token.
|
||||
for seq_id in seq_ids:
|
||||
seq_data = input_metadata.seq_data[seq_id]
|
||||
output_tokens.append(seq_data.output_token_ids)
|
||||
return output_tokens
|
||||
|
||||
|
||||
def _apply_logits_processors(logits: torch.Tensor,
|
||||
input_metadata: InputMetadata) -> torch.Tensor:
|
||||
logits_row_idx = 0
|
||||
found_logits_processors = False
|
||||
for seq_ids, sampling_params in input_metadata.seq_groups:
|
||||
logits_processors = sampling_params.logits_processors
|
||||
if logits_processors:
|
||||
found_logits_processors = True
|
||||
for seq_id in seq_ids:
|
||||
logits_row = logits[logits_row_idx]
|
||||
token_ids = input_metadata.seq_data[seq_id].output_token_ids
|
||||
for logits_processor in logits_processors:
|
||||
logits_row = logits_processor(token_ids, logits_row)
|
||||
logits[logits_row_idx] = logits_row
|
||||
logits_row_idx += 1
|
||||
else:
|
||||
logits_row_idx += len(seq_ids)
|
||||
if found_logits_processors:
|
||||
assert logits_row_idx == logits.shape[0]
|
||||
return logits
|
||||
|
||||
|
||||
def _apply_penalties(
|
||||
logits: torch.Tensor,
|
||||
output_tokens: List[List[int]],
|
||||
presence_penalties: List[float],
|
||||
frequency_penalties: List[float],
|
||||
vocab_size: int,
|
||||
repetition_penalties: List[float],
|
||||
) -> torch.Tensor:
|
||||
num_seqs = logits.shape[0]
|
||||
# Collect the indices of sequences that have non-zero penalties.
|
||||
indices = []
|
||||
num_seqs, vocab_size = logits.shape
|
||||
for i in range(num_seqs):
|
||||
if not output_tokens[i]:
|
||||
continue
|
||||
p = presence_penalties[i]
|
||||
f = frequency_penalties[i]
|
||||
if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
|
||||
r = repetition_penalties[i]
|
||||
if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS and abs(
|
||||
r - 1.0) < _SAMPLING_EPS:
|
||||
continue
|
||||
indices.append(i)
|
||||
|
||||
# Return early if all sequences have zero penalties.
|
||||
if not indices:
|
||||
break
|
||||
else:
|
||||
# Return early if all sequences have zero penalties.
|
||||
return logits
|
||||
|
||||
bin_counts = []
|
||||
for i in indices:
|
||||
bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
|
||||
bin_counts = np.stack(bin_counts, axis=0)
|
||||
bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
|
||||
device=logits.device)
|
||||
max_output_len = max(len(tokens) for tokens in output_tokens)
|
||||
padded_output_tokens = [
|
||||
tokens + [vocab_size] * (max_output_len - len(tokens))
|
||||
for tokens in output_tokens
|
||||
]
|
||||
output_tokens_tensor = torch.tensor(padded_output_tokens,
|
||||
dtype=torch.long,
|
||||
device=logits.device)
|
||||
|
||||
frequency_penalties = [frequency_penalties[i] for i in indices]
|
||||
# Compute the bin counts for the output tokens.
|
||||
# vocab_size + 1 for padding.
|
||||
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
|
||||
dtype=torch.long,
|
||||
device=logits.device)
|
||||
bin_counts.scatter_add_(1, output_tokens_tensor,
|
||||
torch.ones_like(output_tokens_tensor))
|
||||
bin_counts = bin_counts[:, :vocab_size] # Remove the padding bin.
|
||||
mask = bin_counts > 0
|
||||
|
||||
repetition_penalties = torch.tensor(repetition_penalties,
|
||||
dtype=logits.dtype,
|
||||
device=logits.device)
|
||||
frequency_penalties = torch.tensor(frequency_penalties,
|
||||
dtype=logits.dtype,
|
||||
device=logits.device)
|
||||
presence_penalties = [presence_penalties[i] for i in indices]
|
||||
presence_penalties = torch.tensor(presence_penalties,
|
||||
dtype=logits.dtype,
|
||||
device=logits.device)
|
||||
|
||||
repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
|
||||
repetition_penalties[~mask] = 1.0
|
||||
logits = torch.where(logits > 0, logits / repetition_penalties,
|
||||
logits * repetition_penalties)
|
||||
|
||||
# We follow the definition in OpenAI API.
|
||||
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
|
||||
logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
|
||||
presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
|
||||
logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
|
||||
logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
|
||||
logits -= presence_penalties.unsqueeze(dim=1) * mask
|
||||
return logits
|
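The penalty arithmetic above can be checked on a toy example; all numbers below are hypothetical, and the real function builds bin_counts with scatter_add_ over padded output-token tensors:

import torch

# One sequence over a 6-token vocab that has already emitted token 2 twice
# and token 5 once.
logits = torch.tensor([[2.0, -1.0, 0.5, 0.0, -0.5, 1.5]])
bin_counts = torch.tensor([[0., 0., 2., 0., 0., 1.]])
mask = bin_counts > 0

repetition, frequency, presence = 1.2, 0.3, 0.1
rep = torch.where(mask, torch.full_like(logits, repetition),
                  torch.ones_like(logits))
# Repetition penalty shrinks positive logits and pushes negative ones lower.
logits = torch.where(logits > 0, logits / rep, logits * rep)
# OpenAI-style frequency/presence penalties only touch tokens seen so far.
logits -= frequency * bin_counts
logits -= presence * mask.to(logits.dtype)
print(logits)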
||||
|
||||
|
||||
@ -201,38 +258,39 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
|
||||
# (i.e., greedy sampling or beam search).
|
||||
# Set the temperature to 1 to avoid division by zero.
|
||||
temperature = 1.0
|
||||
|
||||
if i < input_metadata.num_prompts:
|
||||
# A prompt input.
|
||||
temperatures.append(temperature)
|
||||
else:
|
||||
# A generation token.
|
||||
temperatures += [temperature] * len(seq_ids)
|
||||
if (i < input_metadata.num_prompts
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
prompt_len = input_metadata.prompt_lens[i]
|
||||
temperatures += [temperature] * (prompt_len - 1)
|
||||
temperatures += [temperature] * len(seq_ids)
|
||||
return temperatures
|
||||
|
||||
|
||||
def _get_top_p_top_k(
|
||||
def _get_top_p_top_k_min_p(
|
||||
input_metadata: InputMetadata,
|
||||
vocab_size: int,
|
||||
) -> Tuple[List[float], List[int]]:
|
||||
) -> Tuple[List[float], List[int], List[float]]:
|
||||
top_ps: List[float] = []
|
||||
top_ks: List[int] = []
|
||||
min_ps: List[float] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_ids, sampling_params = seq_group
|
||||
top_p = sampling_params.top_p
|
||||
min_p = sampling_params.min_p
|
||||
# k should not be greater than the vocab size.
|
||||
top_k = min(sampling_params.top_k, vocab_size)
|
||||
# k=-1 means no truncation.
|
||||
top_k = vocab_size if top_k == -1 else top_k
|
||||
if i < input_metadata.num_prompts:
|
||||
# A prompt input.
|
||||
top_ps.append(top_p)
|
||||
top_ks.append(top_k)
|
||||
else:
|
||||
# A generation token.
|
||||
top_ps += [top_p] * len(seq_ids)
|
||||
top_ks += [top_k] * len(seq_ids)
|
||||
return top_ps, top_ks
|
||||
if (i < input_metadata.num_prompts
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
prompt_len = input_metadata.prompt_lens[i]
|
||||
top_ps += [top_p] * (prompt_len - 1)
|
||||
top_ks += [top_k] * (prompt_len - 1)
|
||||
min_ps += [min_p] * (prompt_len - 1)
|
||||
top_ps += [top_p] * len(seq_ids)
|
||||
top_ks += [top_k] * len(seq_ids)
|
||||
min_ps += [min_p] * len(seq_ids)
|
||||
return top_ps, top_ks, min_ps
|
||||
|
||||
|
||||
def _apply_top_p_top_k(
|
||||
@ -264,164 +322,308 @@ def _apply_top_p_top_k(
|
||||
return logits
|
||||
|
||||
|
||||
def _get_topk_logprobs(
|
||||
def _apply_min_p(
|
||||
logits: torch.Tensor,
|
||||
min_ps: List[float],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Adapted from
|
||||
https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
|
||||
"""
|
||||
min_p = torch.tensor(min_ps, dtype=logits.dtype, device=logits.device)
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
top_probs, _ = probs.max(dim=-1, keepdim=True)
|
||||
scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
|
||||
tokens_to_remove = probs < scaled_min_p
|
||||
logits = logits.masked_fill(tokens_to_remove, -float("inf"))
|
||||
|
||||
return logits
|
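A quick numerical reading of the min-p rule implemented above (hypothetical logits): with min_p = 0.1, any token whose probability falls below 10% of the most likely token's probability is masked out before sampling.

import torch

logits = torch.tensor([[3.0, 2.0, 0.0, -2.0]])
probs = torch.softmax(logits, dim=-1)       # ~[0.70, 0.26, 0.035, 0.005]
min_p = torch.tensor([0.1])
scaled_min_p = min_p.unsqueeze(1) * probs.max(dim=-1, keepdim=True).values
filtered = logits.masked_fill(probs < scaled_min_p, -float("inf"))
print(filtered)   # the last two tokens can no longer be sampled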
||||
|
||||
|
||||
def _greedy_sample(
|
||||
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||
logprobs: torch.Tensor,
|
||||
num_logprobs: Optional[int],
|
||||
) -> Dict[int, float]:
|
||||
if num_logprobs is None or num_logprobs == 0:
|
||||
return {}
|
||||
|
||||
topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
|
||||
if num_logprobs == 1:
|
||||
topk_logprobs = [topk_logprobs.item()]
|
||||
topk_ids = [topk_ids.item()]
|
||||
else:
|
||||
topk_logprobs = topk_logprobs.tolist()
|
||||
topk_ids = topk_ids.tolist()
|
||||
|
||||
token_to_logprob: Dict[int, float] = {}
|
||||
for token_id, logprob in zip(topk_ids, topk_logprobs):
|
||||
token_to_logprob[token_id] = logprob
|
||||
return token_to_logprob
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
samples = torch.argmax(logprobs, dim=-1).cpu()
|
||||
sample_idx = 0
|
||||
results = []
|
||||
for seq_group in selected_seq_groups:
|
||||
seq_ids, _ = seq_group
|
||||
num_parent_seqs = len(seq_ids)
|
||||
assert num_parent_seqs == 1, (
|
||||
"Greedy sampling should have only one seq.")
|
||||
parent_ids = list(range(num_parent_seqs))
|
||||
next_token_ids = [samples[sample_idx].item()]
|
||||
results.append((next_token_ids, parent_ids))
|
||||
sample_idx += num_parent_seqs
|
||||
assert sample_idx == logprobs.size(0)
|
||||
return results
|
||||
|
||||
|
||||
def _sample_from_prompt(
|
||||
prob: torch.Tensor,
|
||||
sampling_params: SamplingParams,
|
||||
) -> List[int]:
|
||||
if sampling_params.use_beam_search:
|
||||
# Beam search.
|
||||
beam_width = sampling_params.best_of
|
||||
# Sample 2 * beam_width candidates to make sure that with high
|
||||
# probability we can get `beam_width` candidates in addition to
|
||||
# the finished sequences for the next iteration. See
|
||||
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
|
||||
# for details. See also HF reference:
|
||||
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
|
||||
_, next_token_ids = torch.topk(prob, 2 * beam_width)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
elif sampling_params.temperature < _SAMPLING_EPS:
|
||||
# Greedy sampling.
|
||||
assert sampling_params.best_of == 1
|
||||
next_token_id = torch.argmax(prob)
|
||||
next_token_ids = [next_token_id.item()]
|
||||
else:
|
||||
# Random sampling.
|
||||
# Sample `best_of` tokens for the prompt.
|
||||
num_seqs = sampling_params.best_of
|
||||
next_token_ids = torch.multinomial(prob,
|
||||
num_samples=num_seqs,
|
||||
replacement=True)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
return next_token_ids
|
||||
|
||||
|
||||
def _sample_from_generation_tokens(
|
||||
seq_ids: List[int],
|
||||
def _random_sample(
|
||||
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||
is_prompts: List[bool],
|
||||
probs: torch.Tensor,
|
||||
logprobs: torch.Tensor,
|
||||
seq_logprobs: List[float],
|
||||
sampling_params: SamplingParams,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
# NOTE(woosuk): sampling_params.best_of can be greater than
|
||||
# len(seq_ids) because some sequences in the group might have
|
||||
# been already terminated.
|
||||
if sampling_params.use_beam_search:
|
||||
# Beam search.
|
||||
# Add cumulative logprobs for the sequences in the group.
|
||||
seq_logprobs = torch.tensor(seq_logprobs,
|
||||
dtype=torch.float,
|
||||
device=logprobs.device)
|
||||
logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
# Find the maximum best_of value of the prompt phase requests.
|
||||
max_best_of = 1
|
||||
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
||||
if is_prompt:
|
||||
seq_ids, sampling_params = seq_group
|
||||
max_best_of = max(max_best_of, sampling_params.best_of)
|
||||
random_samples = torch.multinomial(probs,
|
||||
num_samples=max_best_of,
|
||||
replacement=True).cpu()
|
||||
sample_idx = 0
|
||||
results = []
|
||||
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
||||
seq_ids, sampling_params = seq_group
|
||||
num_parent_seqs = len(seq_ids)
|
||||
if is_prompt:
|
||||
# Prompt phase.
|
||||
assert num_parent_seqs == 1, (
|
||||
"Prompt input should have only one seq.")
|
||||
parent_ids = [0] * sampling_params.best_of
|
||||
next_token_ids = random_samples[
|
||||
sample_idx, :sampling_params.best_of].tolist()
|
||||
else:
|
||||
# Generation phase.
|
||||
parent_ids = list(range(num_parent_seqs))
|
||||
next_token_ids = random_samples[sample_idx:sample_idx +
|
||||
num_parent_seqs, 0].tolist()
|
||||
results.append((next_token_ids, parent_ids))
|
||||
sample_idx += num_parent_seqs
|
||||
assert sample_idx == probs.size(0)
|
||||
return results
|
||||
|
||||
vocab_size = logprobs.size(-1)
|
||||
beam_width = len(seq_ids)
|
||||
_, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
|
||||
topk_ids = topk_ids.tolist()
|
||||
seq_idx = [i // vocab_size for i in topk_ids]
|
||||
parent_seq_ids = [seq_ids[i] for i in seq_idx]
|
||||
next_token_ids = [i % vocab_size for i in topk_ids]
|
||||
elif sampling_params.temperature < _SAMPLING_EPS:
|
||||
# Greedy sampling.
|
||||
assert len(seq_ids) == 1
|
||||
next_token_id = torch.argmax(probs, dim=-1)
|
||||
next_token_ids = [int(next_token_id.item())]
|
||||
parent_seq_ids = seq_ids
|
||||
else:
|
||||
# Random sampling.
|
||||
# Sample 1 token for each sequence in the group.
|
||||
next_token_ids = torch.multinomial(probs,
|
||||
num_samples=1,
|
||||
replacement=True)
|
||||
next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
|
||||
parent_seq_ids = seq_ids
|
||||
return parent_seq_ids, next_token_ids
|
||||
|
||||
def _beam_search_sample(
|
||||
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||
is_prompts: List[bool],
|
||||
seq_data: Dict[int, SequenceData],
|
||||
logprobs: torch.Tensor,
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
# We sample 2 * beam_width candidates to make sure that with high
|
||||
# probability we can get `beam_width` candidates in addition to
|
||||
# the finished sequences for the next iteration. See
|
||||
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
|
||||
# for details. See also HF reference:
|
||||
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
|
||||
#
|
||||
# NOTE: Beam search is not vectorized, so its speed can be slower than
|
||||
# other sampling methods.
|
||||
sample_idx = 0
|
||||
results = []
|
||||
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
||||
seq_ids, sampling_params = seq_group
|
||||
num_parent_seqs = len(seq_ids)
|
||||
beam_width = sampling_params.best_of
|
||||
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
|
||||
if is_prompt:
|
||||
# Prompt phase.
|
||||
assert num_parent_seqs == 1, (
|
||||
"Prompt input should have only one seq.")
|
||||
parent_ids = [0] * (2 * beam_width)
|
||||
_, next_token_ids = torch.topk(seq_group_logprobs[0],
|
||||
2 * beam_width)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
else:
|
||||
# Generation phase.
|
||||
cumulative_logprobs = [
|
||||
seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
|
||||
]
|
||||
cumulative_logprobs = torch.tensor(
|
||||
cumulative_logprobs,
|
||||
dtype=torch.float,
|
||||
device=seq_group_logprobs.device)
|
||||
seq_group_logprobs = (seq_group_logprobs +
|
||||
cumulative_logprobs.unsqueeze(dim=1))
|
||||
_, topk_ids = torch.topk(seq_group_logprobs.flatten(),
|
||||
2 * beam_width)
|
||||
topk_ids = topk_ids.tolist()
|
||||
vocab_size = seq_group_logprobs.size(-1)
|
||||
parent_ids = [i // vocab_size for i in topk_ids]
|
||||
next_token_ids = [i % vocab_size for i in topk_ids]
|
||||
results.append((next_token_ids, parent_ids))
|
||||
sample_idx += num_parent_seqs
|
||||
assert sample_idx == logprobs.size(0)
|
||||
return results
|
||||
|
||||
|
||||
def _sample(
|
||||
probs: torch.Tensor,
|
||||
logprobs: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> SamplerOutput:
|
||||
seq_outputs: SamplerOutput = []
|
||||
|
||||
# TODO(woosuk): Optimize.
|
||||
idx = 0
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
categorized_seq_group_ids = {t: [] for t in SamplingType}
|
||||
categorized_sample_indices = input_metadata.categorized_sample_indices
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_group_outputs: List[SequenceOutputs] = []
|
||||
seq_ids, sampling_params = seq_group
|
||||
if i < input_metadata.num_prompts:
|
||||
# Generate the next tokens for a prompt input.
|
||||
assert len(seq_ids) == 1, "Prompt input should have only one seq."
|
||||
parent_seq_id = seq_ids[0]
|
||||
prob = probs[idx]
|
||||
logprob = logprobs[idx]
|
||||
idx += 1
|
||||
_, sampling_params = seq_group
|
||||
sampling_type = sampling_params.sampling_type
|
||||
categorized_seq_group_ids[sampling_type].append(i)
|
||||
|
||||
# Sample the next tokens.
|
||||
next_token_ids = _sample_from_prompt(prob, sampling_params)
|
||||
# Get top-k log probabilities for the next tokens.
|
||||
next_logprobs = _get_topk_logprobs(logprob,
|
||||
sampling_params.logprobs)
|
||||
|
||||
# Build the output.
|
||||
for next_token_id in next_token_ids:
|
||||
output_logprobs = next_logprobs.copy()
|
||||
output_logprobs[next_token_id] = logprob[next_token_id].item()
|
||||
seq_group_outputs.append(
|
||||
SequenceOutputs(parent_seq_id, next_token_id,
|
||||
output_logprobs))
|
||||
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
|
||||
for sampling_type in SamplingType:
|
||||
seq_group_ids = categorized_seq_group_ids[sampling_type]
|
||||
seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
|
||||
is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
|
||||
sample_indices = categorized_sample_indices[sampling_type]
|
||||
num_tokens = len(sample_indices)
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
if sampling_type == SamplingType.GREEDY:
|
||||
category_logprobs = logprobs[sample_indices]
|
||||
sample_results = _greedy_sample(seq_groups, category_logprobs)
|
||||
elif sampling_type == SamplingType.RANDOM:
|
||||
category_probs = probs[sample_indices]
|
||||
sample_results = _random_sample(seq_groups, is_prompts,
|
||||
category_probs)
|
||||
elif sampling_type == SamplingType.BEAM:
|
||||
category_logprobs = logprobs[sample_indices]
|
||||
sample_results = _beam_search_sample(seq_groups, is_prompts,
|
||||
input_metadata.seq_data,
|
||||
category_logprobs)
|
||||
else:
|
||||
# Generate the next tokens for generation tokens.
|
||||
num_parent_seqs = len(seq_ids)
|
||||
prob = probs[idx:idx + num_parent_seqs]
|
||||
logprob = logprobs[idx:idx + num_parent_seqs]
|
||||
idx += num_parent_seqs
|
||||
raise ValueError(f"Unsupported sampling type: {sampling_type}")
|
||||
sample_results_dict.update(zip(seq_group_ids, sample_results))
|
||||
|
||||
# Sample the next tokens.
|
||||
seq_logprobs = [
|
||||
input_metadata.seq_data[seq_id].cumulative_logprob
|
||||
for seq_id in seq_ids
|
||||
]
|
||||
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
|
||||
seq_ids, prob, logprob, seq_logprobs, sampling_params)
|
||||
sample_results = [
|
||||
sample_results_dict[i] for i in range(len(input_metadata.seq_groups))
|
||||
]
|
||||
return sample_results
|
||||
|
||||
# Get top-k log probabilities for the next tokens.
|
||||
next_logprobs: Dict[int, Dict[int, float]] = {}
|
||||
for j, seq_id in enumerate(seq_ids):
|
||||
next_logprobs[seq_id] = _get_topk_logprobs(
|
||||
logprob[j], sampling_params.logprobs)
|
||||
|
||||
# Build the output.
|
||||
for parent_seq_id, next_token_id in zip(parent_seq_ids,
|
||||
next_token_ids):
|
||||
j = seq_ids.index(parent_seq_id)
|
||||
output_logprobs = next_logprobs[parent_seq_id].copy()
|
||||
output_logprobs[next_token_id] = logprob[j,
|
||||
next_token_id].item()
|
||||
seq_group_outputs.append(
|
||||
SequenceOutputs(parent_seq_id, next_token_id,
|
||||
output_logprobs))
|
||||
seq_outputs.append(seq_group_outputs)
|
||||
def _get_logprobs(
|
||||
logprobs: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
sample_results: List[Tuple[List[int], List[int]]],
|
||||
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
|
||||
int, float]]]]:
|
||||
# Prepare query indices
|
||||
batched_logprobs_query_seq_indices: List[int] = []
|
||||
batched_logprobs_query_token_indices: List[int] = []
|
||||
largest_num_logprobs = 0
|
||||
sample_idx = 0
|
||||
for i, (seq_group, sample_result) in enumerate(
|
||||
zip(input_metadata.seq_groups, sample_results)):
|
||||
seq_ids, sampling_params = seq_group
|
||||
next_token_ids, parent_ids = sample_result
|
||||
num_parent_seqs = len(seq_ids)
|
||||
if (i < input_metadata.num_prompts
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
largest_num_logprobs = max(largest_num_logprobs,
|
||||
sampling_params.prompt_logprobs)
|
||||
prompt_len = input_metadata.prompt_lens[i]
|
||||
prompt_tokens = input_metadata.seq_data[
|
||||
seq_ids[0]].prompt_token_ids
|
||||
batched_logprobs_query_seq_indices.extend(
|
||||
sample_idx + j for j in range(prompt_len - 1))
|
||||
batched_logprobs_query_token_indices.extend(
|
||||
token_id for token_id in prompt_tokens[1:])
|
||||
sample_idx += prompt_len - 1
|
||||
batched_logprobs_query_seq_indices.extend(
|
||||
[sample_idx + parent_id for parent_id in parent_ids])
|
||||
batched_logprobs_query_token_indices.extend(next_token_ids)
|
||||
if sampling_params.logprobs is not None:
|
||||
largest_num_logprobs = max(largest_num_logprobs,
|
||||
sampling_params.logprobs)
|
||||
sample_idx += num_parent_seqs
|
||||
assert sample_idx == logprobs.size(0)
|
||||
|
||||
return seq_outputs
|
||||
# Batched query for logprobs of selected token
|
||||
batched_logprobs_query_result = logprobs[[
|
||||
batched_logprobs_query_seq_indices,
|
||||
batched_logprobs_query_token_indices
|
||||
]].cpu()
|
||||
|
||||
# Batched query for logprobs of topk tokens
|
||||
if largest_num_logprobs > 0:
|
||||
top_logprobs, top_token_ids = torch.topk(logprobs,
|
||||
largest_num_logprobs,
|
||||
dim=-1)
|
||||
top_logprobs = top_logprobs.cpu()
|
||||
top_token_ids = top_token_ids.cpu()
|
||||
else:
|
||||
top_logprobs, top_token_ids = None, None
|
||||
|
||||
# Gather results
|
||||
result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
|
||||
result_sample_logprobs: List[SampleLogprobs] = []
|
||||
sample_idx = 0
|
||||
query_result_idx = 0
|
||||
for i, (seq_group, sample_result) in enumerate(
|
||||
zip(input_metadata.seq_groups, sample_results)):
|
||||
seq_ids, sampling_params = seq_group
|
||||
next_token_ids, parent_ids = sample_result
|
||||
|
||||
# Prompt logprobs
|
||||
if (i < input_metadata.num_prompts
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
num_logprobs = sampling_params.prompt_logprobs
|
||||
prompt_len = input_metadata.prompt_lens[i]
|
||||
prompt_tokens = input_metadata.seq_data[
|
||||
seq_ids[0]].prompt_token_ids
|
||||
group_prompt_logprobs: PromptLogprobs = [None]
|
||||
for token_id in prompt_tokens[1:]:
|
||||
prompt_logprobs_dict = {
|
||||
token_id:
|
||||
batched_logprobs_query_result[query_result_idx].item()
|
||||
}
|
||||
if num_logprobs > 0:
|
||||
prompt_logprobs_dict.update(
|
||||
zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
|
||||
top_logprobs[sample_idx, :num_logprobs].tolist()))
|
||||
group_prompt_logprobs.append(prompt_logprobs_dict)
|
||||
sample_idx += 1
|
||||
query_result_idx += 1
|
||||
result_prompt_logprobs.append(group_prompt_logprobs)
|
||||
else:
|
||||
result_prompt_logprobs.append(None)
|
||||
|
||||
# Sample logprobs
|
||||
num_logprobs = sampling_params.logprobs
|
||||
if num_logprobs is None:
|
||||
num_logprobs = 0
|
||||
group_sample_logprobs: SampleLogprobs = []
|
||||
for next_token_id, parent_id in zip(next_token_ids, parent_ids):
|
||||
sample_logprobs_dict = {
|
||||
next_token_id:
|
||||
batched_logprobs_query_result[query_result_idx].item()
|
||||
}
|
||||
query_result_idx += 1
|
||||
if num_logprobs > 0:
|
||||
sample_logprobs_dict.update(
|
||||
zip(
|
||||
top_token_ids[sample_idx +
|
||||
parent_id, :num_logprobs].tolist(),
|
||||
top_logprobs[sample_idx +
|
||||
parent_id, :num_logprobs].tolist()))
|
||||
group_sample_logprobs.append(sample_logprobs_dict)
|
||||
result_sample_logprobs.append(group_sample_logprobs)
|
||||
sample_idx += len(seq_ids)
|
||||
|
||||
return result_prompt_logprobs, result_sample_logprobs
|
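The batched query above gathers one logprob per (row, token) pair in a single indexing operation; the miniature below shows the same gather pattern with hypothetical shapes.

import torch

# Hypothetical: 3 sampled positions over a 5-token vocab.
logprobs = torch.log_softmax(torch.randn(3, 5), dim=-1)
seq_indices = [0, 1, 2]      # which row each query reads from
token_indices = [4, 0, 2]    # which token's logprob is wanted in that row
selected = logprobs[seq_indices, token_indices]   # shape: (3,)
# Same result as stacking logprobs[i, t] for each (i, t) pair individually.
print(selected)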
||||
|
||||
|
||||
def _build_sampler_output(
|
||||
sample_results: List[Tuple[List[int], List[int]]],
|
||||
input_metadata: InputMetadata,
|
||||
prompt_logprobs: List[Optional[PromptLogprobs]],
|
||||
sample_logprobs: List[SampleLogprobs],
|
||||
) -> SamplerOutput:
|
||||
sampler_output = []
|
||||
for (seq_group, sample_result, group_prompt_logprobs,
|
||||
group_sample_logprobs) in zip(input_metadata.seq_groups,
|
||||
sample_results, prompt_logprobs,
|
||||
sample_logprobs):
|
||||
seq_ids, _ = seq_group
|
||||
next_token_ids, parent_ids = sample_result
|
||||
seq_outputs = []
|
||||
for parent_id, next_token_id, logprobs in zip(parent_ids,
|
||||
next_token_ids,
|
||||
group_sample_logprobs):
|
||||
seq_outputs.append(
|
||||
SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs))
|
||||
sampler_output.append(
|
||||
SequenceGroupOutputs(seq_outputs, group_prompt_logprobs))
|
||||
return sampler_output
|
||||
|
139
vllm/model_executor/layers/vocab_parallel_embedding.py
Normal file
@ -0,0 +1,139 @@
from typing import Optional, Sequence

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.utils import divide
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce)
from vllm.model_executor.utils import set_weight_attrs


def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    """Pad the vocab size to the given value."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int,
                                              rank: int) -> Sequence[int]:
    index_f = rank * per_partition_vocab_size
    index_l = index_f + per_partition_vocab_size
    return index_f, index_l


def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
                                       world_size: int) -> Sequence[int]:
    per_partition_vocab_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
                                                     rank)

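To make the helpers above concrete (hypothetical sizes, not from the diff): a 32,000-token vocabulary is already a multiple of 64, and splitting it across four tensor-parallel ranks gives each rank an 8,000-token slice.

# Hypothetical: LLaMA-like 32000-token vocab over 4 tensor-parallel GPUs.
vocab_size = 32000
padded = pad_vocab_size(vocab_size)       # 32000 (already divisible by 64)
start, end = vocab_range_from_global_vocab_size(padded, rank=1, world_size=4)
print(padded, start, end)                 # 32000 8000 16000
# Rank 1 therefore stores embedding rows for token ids [8000, 16000).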
class VocabParallelEmbedding(torch.nn.Module):
|
||||
"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
|
||||
make sure it is divisible by the number of model parallel GPUs.
|
||||
|
||||
Args:
|
||||
num_embeddings: vocabulary size.
|
||||
embedding_dim: size of hidden state.
|
||||
params_dtype: type of the parameters.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
params_dtype: Optional[torch.dtype] = None):
|
||||
super().__init__()
|
||||
|
||||
# Keep the input dimensions.
|
||||
self.num_embeddings = num_embeddings
|
||||
self.num_embeddings_padded = pad_vocab_size(num_embeddings)
|
||||
self.embedding_dim = embedding_dim
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
# Divide the weight matrix along the vocabulary dimension.
|
||||
self.vocab_start_index, self.vocab_end_index = (
|
||||
vocab_range_from_global_vocab_size(
|
||||
self.num_embeddings_padded, get_tensor_model_parallel_rank(),
|
||||
self.tp_size))
|
||||
self.num_embeddings_per_partition = (self.vocab_end_index -
|
||||
self.vocab_start_index)
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.num_embeddings_per_partition,
|
||||
self.embedding_dim,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.weight, {
|
||||
"parallel_dim": 0,
|
||||
"weight_loader": self.weight_loader
|
||||
})
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
parallel_dim = param.parallel_dim
|
||||
assert loaded_weight.shape[parallel_dim] == self.num_embeddings
|
||||
loaded_weight = loaded_weight[self.vocab_start_index:self.
|
||||
vocab_end_index]
|
||||
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, input_):
|
||||
if self.tp_size > 1:
|
||||
# Build the mask.
|
||||
input_mask = ((input_ < self.vocab_start_index) |
|
||||
(input_ >= self.vocab_end_index))
|
||||
# Mask the input.
|
||||
masked_input = input_.clone() - self.vocab_start_index
|
||||
masked_input[input_mask] = 0
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = F.embedding(masked_input, self.weight)
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel[input_mask, :] = 0.0
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
return output
|
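The masking logic in forward can be reproduced single-process with plain tensors. This is an illustrative sketch with hypothetical sizes; in the real module the final all-reduce sums the per-rank partial embeddings.

import torch
import torch.nn.functional as F

# Pretend this rank owns vocab ids [8000, 16000) of the padded vocabulary.
vocab_start, vocab_end, hidden = 8000, 16000, 16
weight = torch.randn(vocab_end - vocab_start, hidden)

input_ids = torch.tensor([[7999, 8000, 15999, 16000]])
input_mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
masked_input = input_ids.clone() - vocab_start
masked_input[input_mask] = 0              # any in-range row; zeroed below
out = F.embedding(masked_input, weight)
out[input_mask, :] = 0.0                  # off-rank tokens contribute nothing
# tensor_model_parallel_all_reduce(out) would then combine all ranks' outputs.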
||||
|
||||
|
||||
class ParallelLMHead(VocabParallelEmbedding):
|
||||
"""Parallelized LM head.
|
||||
|
||||
Output logits weight matrices used in the Sampler. The weight and bias
|
||||
tensors are padded to make sure they are divisible by the number of
|
||||
model parallel GPUs.
|
||||
|
||||
Args:
|
||||
num_embeddings: vocabulary size.
|
||||
embedding_dim: size of hidden state.
|
||||
bias: whether to use bias.
|
||||
params_dtype: type of the parameters.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
bias: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None):
|
||||
super().__init__(num_embeddings, embedding_dim, params_dtype)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.num_embeddings_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"parallel_dim": 0,
|
||||
"weight_loader": self.weight_loader
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, input_):
|
||||
del input_
|
||||
raise RuntimeError("LMHead's weights should be used in the sampler.")
|
@ -8,14 +8,17 @@ from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.model_executor.models import * # pylint: disable=wildcard-import
|
||||
from vllm.model_executor.weight_utils import initialize_dummy_weights
|
||||
from vllm.model_executor.weight_utils import (get_quant_config,
|
||||
initialize_dummy_weights)
|
||||
|
||||
# TODO(woosuk): Lazy-load the model classes.
|
||||
_MODEL_REGISTRY = {
|
||||
"AquilaModel": AquilaForCausalLM,
|
||||
"AquilaForCausalLM": AquilaForCausalLM, # AquilaChat2
|
||||
"BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b
|
||||
"BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b
|
||||
"BloomForCausalLM": BloomForCausalLM,
|
||||
"ChatGLMModel": ChatGLMForCausalLM,
|
||||
"FalconForCausalLM": FalconForCausalLM,
|
||||
"GPT2LMHeadModel": GPT2LMHeadModel,
|
||||
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
|
||||
@ -24,10 +27,15 @@ _MODEL_REGISTRY = {
|
||||
"InternLMForCausalLM": InternLMForCausalLM,
|
||||
"LlamaForCausalLM": LlamaForCausalLM,
|
||||
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
|
||||
"MistralForCausalLM": MistralForCausalLM,
|
||||
# transformers's mpt class has lower case
|
||||
"MptForCausalLM": MPTForCausalLM,
|
||||
"MPTForCausalLM": MPTForCausalLM,
|
||||
"OPTForCausalLM": OPTForCausalLM,
|
||||
"PhiForCausalLM": PhiForCausalLM,
|
||||
"QWenLMHeadModel": QWenLMHeadModel,
|
||||
"RWForCausalLM": FalconForCausalLM,
|
||||
"YiForCausalLM": YiForCausalLM,
|
||||
}
|
||||
|
||||
|
||||
@ -52,10 +60,34 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
|
||||
def get_model(model_config: ModelConfig) -> nn.Module:
|
||||
model_class = _get_model_architecture(model_config.hf_config)
|
||||
|
||||
# Get the (maybe quantized) linear method.
|
||||
linear_method = None
|
||||
if model_config.quantization is not None:
|
||||
quant_config = get_quant_config(model_config.quantization,
|
||||
model_config.model,
|
||||
model_config.hf_config,
|
||||
model_config.download_dir)
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
if capability < quant_config.get_min_capability():
|
||||
raise ValueError(
|
||||
f"The quantization method {model_config.quantization} is not "
|
||||
"supported for the current GPU. "
|
||||
f"Minimum capability: {quant_config.get_min_capability()}. "
|
||||
f"Current capability: {capability}.")
|
||||
supported_dtypes = quant_config.get_supported_act_dtypes()
|
||||
if model_config.dtype not in supported_dtypes:
|
||||
raise ValueError(
|
||||
f"{model_config.dtype} is not supported for quantization "
|
||||
f"method {model_config.quantization}. Supported dtypes: "
|
||||
f"{supported_dtypes}")
|
||||
linear_method = quant_config.get_linear_method()
|
||||
|
||||
with _set_default_torch_dtype(model_config.dtype):
|
||||
# Create a model instance.
|
||||
# The weights will be initialized as empty tensors.
|
||||
model = model_class(model_config.hf_config)
|
||||
model = model_class(model_config.hf_config, linear_method)
|
||||
if model_config.load_format == "dummy":
|
||||
model = model.cuda()
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
@ -64,6 +96,6 @@ def get_model(model_config: ModelConfig) -> nn.Module:
|
||||
else:
|
||||
# Load the weights from the cached or downloaded files.
|
||||
model.load_weights(model_config.model, model_config.download_dir,
|
||||
model_config.load_format)
|
||||
model_config.load_format, model_config.revision)
|
||||
model = model.cuda()
|
||||
return model.eval()
|
||||
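The compute-capability gate used above can be exercised on its own; the sketch below assumes a CUDA device is visible, and the minimum-capability value 75 is only an example of what a quantization config might report.

import torch

def check_quant_capability(min_capability: int) -> None:
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor          # e.g. (8, 0) on an A100 -> 80
    if capability < min_capability:
        raise ValueError(
            f"Quantization needs capability >= {min_capability}, "
            f"but this GPU reports {capability}.")

check_quant_capability(75)   # example threshold; real value comes from the config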
|
@ -9,15 +9,20 @@ from vllm.model_executor.models.gpt_j import GPTJForCausalLM
|
||||
from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
|
||||
from vllm.model_executor.models.internlm import InternLMForCausalLM
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.model_executor.models.mistral import MistralForCausalLM
|
||||
from vllm.model_executor.models.mpt import MPTForCausalLM
|
||||
from vllm.model_executor.models.opt import OPTForCausalLM
|
||||
from vllm.model_executor.models.phi_1_5 import PhiForCausalLM
|
||||
from vllm.model_executor.models.qwen import QWenLMHeadModel
|
||||
from vllm.model_executor.models.chatglm import ChatGLMForCausalLM
|
||||
from vllm.model_executor.models.yi import YiForCausalLM
|
||||
|
||||
__all__ = [
|
||||
"AquilaForCausalLM",
|
||||
"BaiChuanForCausalLM",
|
||||
"BaichuanForCausalLM",
|
||||
"BloomForCausalLM",
|
||||
"ChatGLMForCausalLM",
|
||||
"FalconForCausalLM",
|
||||
"GPT2LMHeadModel",
|
||||
"GPTBigCodeForCausalLM",
|
||||
@ -27,5 +32,8 @@ __all__ = [
|
||||
"LlamaForCausalLM",
|
||||
"MPTForCausalLM",
|
||||
"OPTForCausalLM",
|
||||
"PhiForCausalLM",
|
||||
"QWenLMHeadModel",
|
||||
"MistralForCausalLM",
|
||||
"YiForCausalLM",
|
||||
]
|
||||
|
@ -25,7 +25,7 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -33,14 +33,17 @@ from torch import nn
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.aquila import AquilaConfig
|
||||
|
||||
@ -54,18 +57,17 @@ class AquilaMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = ColumnParallelLinear(hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
linear_method=linear_method)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -105,6 +107,10 @@ class AquilaAttention(nn.Module):
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
max_position_embeddings: int = 8192,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -119,28 +125,32 @@ class AquilaAttention(nn.Module):
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
base=self.rope_theta,
|
||||
max_position=self.max_position_embeddings,
|
||||
rotary_dim=self.head_dim,
|
||||
)
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
rope_scaling=rope_scaling)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -161,18 +171,31 @@ class AquilaAttention(nn.Module):
|
||||
|
||||
class AquilaDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: AquilaConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: AquilaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = AquilaAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rope_scaling=rope_scaling,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.mlp = AquilaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.input_layernorm = AquilaRMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -209,19 +232,22 @@ class AquilaDecoderLayer(nn.Module):
|
||||
|
||||
class AquilaModel(nn.Module):
|
||||
|
||||
def __init__(self, config: AquilaConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: AquilaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
#vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
perform_initialization=False)
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers)
|
||||
AquilaDecoderLayer(config, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@ -254,16 +280,16 @@ class AquilaModel(nn.Module):
|
||||
|
||||
class AquilaForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = AquilaModel(config)
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.linear_method = linear_method
|
||||
self.model = AquilaModel(config, linear_method)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
@ -280,78 +306,33 @@ class AquilaForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
|
||||
]
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||
kv_proj_shard_size = (self.config.hidden_size //
|
||||
self.config.num_attention_heads *
|
||||
self.config.num_attention_heads // tp_size)
|
||||
attention_weight_specs = [
|
||||
# (weight_name, shard_size, offset)
|
||||
("q_proj", q_proj_shard_size, 0),
|
||||
("k_proj", kv_proj_shard_size, q_proj_shard_size),
|
||||
("v_proj", kv_proj_shard_size,
|
||||
q_proj_shard_size + kv_proj_shard_size),
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
state_dict = self.state_dict()
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
is_attention_weight = False
|
||||
for weight_name, shard_size, offset in attention_weight_specs:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "qkv_proj")]
|
||||
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[offset:offset + shard_size]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_attention_weight = True
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
if is_attention_weight:
|
||||
continue
|
||||
|
||||
is_gate_up_weight = False
|
||||
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
break
|
||||
if is_gate_up_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tensor_model_parallel_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tensor_model_parallel_rank)
|
||||
        else:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
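The rewritten `load_weights` drops the manual tensor-parallel slicing in favor of a `stacked_params_mapping` plus per-parameter `weight_loader` callbacks. A minimal sketch of that control flow, assuming (as the diff shows) that fused parameters carry a `weight_loader` attribute; `load_stacked_weights` and the fallback lambda are hypothetical stand-ins, not vLLM API:

```python
from typing import Iterable, Tuple

import torch
from torch import nn


def load_stacked_weights(model: nn.Module,
                         weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
    """Sketch only: mirrors the mapping-driven loader introduced in this diff."""
    stacked_params_mapping = [
        # (fused param name, checkpoint shard name, shard id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in weights:
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            # The fused parameter's weight_loader places this shard at the
            # right offset and applies any tensor-parallel slicing.
            param = params_dict[name.replace(shard_name, param_name)]
            param.weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Un-fused parameters: plain copy unless a custom loader is attached.
            param = params_dict[name]
            loader = getattr(param, "weight_loader",
                             lambda p, w: p.data.copy_(w))
            loader(param, loaded_weight)
```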
@ -30,17 +30,20 @@ from torch import nn
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
|
||||
PagedAttentionWithALiBi)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
|
||||
|
||||
@ -79,18 +82,17 @@ class BaiChuanMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = ColumnParallelLinear(hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
linear_method=linear_method)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
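Why `[intermediate_size] * 2`: the merged weight of `gate_up_proj` is just the gate and up projection matrices stacked along the output dimension, which is also what lets the loader drop each checkpoint shard in by id (0 for gate, 1 for up). A standalone check of that equivalence:

```python
import torch

hidden, inter = 8, 16
gate_w = torch.randn(inter, hidden)          # gate_proj.weight
up_w = torch.randn(inter, hidden)            # up_proj.weight
merged_w = torch.cat([gate_w, up_w], dim=0)  # gate_up_proj.weight: (2 * inter, hidden)

x = torch.randn(4, hidden)
gate, up = (x @ merged_w.t()).chunk(2, dim=-1)
assert torch.allclose(gate, x @ gate_w.t())
assert torch.allclose(up, x @ up_w.t())
```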
@ -111,6 +113,9 @@ class BaiChuanAttention(nn.Module):
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
position_embedding: str,
|
||||
rope_theta: float = 10000,
|
||||
max_position_embeddings: int = 8192,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -122,21 +127,23 @@ class BaiChuanAttention(nn.Module):
|
||||
tensor_model_parallel_world_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.postion_embedding = position_embedding
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
self.W_pack = ColumnParallelLinear(
|
||||
self.W_pack = QKVParallelLinear(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_heads,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
# Create the alibi slopes and slice them.
|
||||
if self.postion_embedding == "ALIBI":
|
||||
@ -151,10 +158,13 @@ class BaiChuanAttention(nn.Module):
|
||||
scaling, alibi_slopes)
|
||||
else:
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
rotary_dim=self.head_dim)
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
rotary_dim=self.head_dim,
|
||||
base=self.rope_theta,
|
||||
max_position=self.max_position_embeddings)
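The fused `W_pack` is now declared as `QKVParallelLinear(hidden_size, head_dim, total_num_heads, total_num_heads, ...)` rather than a generic `ColumnParallelLinear(hidden_size, 3 * hidden_size)`. The output is still the concatenated `[q | k | v]` projection; the dedicated layer exists so sharding and weight loading know the per-head layout. A rough, unparallelized shape sketch (not the vLLM class itself):

```python
import torch
from torch import nn

hidden_size, num_heads = 32, 4
head_dim = hidden_size // num_heads

# Stand-in for QKVParallelLinear(hidden_size, head_dim, num_heads, num_heads).
w_pack = nn.Linear(hidden_size, (num_heads + 2 * num_heads) * head_dim, bias=False)

x = torch.randn(3, hidden_size)
q, k, v = w_pack(x).split([num_heads * head_dim] * 3, dim=-1)
print(q.shape, k.shape, v.shape)  # three tensors of shape (3, 32)
```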
def forward(
|
||||
self,
|
||||
@ -180,18 +190,28 @@ class BaiChuanAttention(nn.Module):
|
||||
|
||||
class BaiChuanDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: BaiChuanConfig, position_embedding: str):
|
||||
def __init__(self,
|
||||
config: BaiChuanConfig,
|
||||
position_embedding: str,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = BaiChuanAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
position_embedding=position_embedding,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.mlp = BaiChuanMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -205,10 +225,15 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
@ -216,19 +241,20 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
return hidden_states, residual
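The decoder layer now threads `residual` through and calls `RMSNorm` with `(hidden_states, residual)`, getting back both the normalized activations and the updated residual so the addition can be fused with the normalization kernel. A hedged, eager-mode sketch of that two-argument contract (not the fused CUDA implementation):

```python
from typing import Optional, Tuple, Union

import torch
from torch import nn


class AddRMSNormSketch(nn.Module):
    """Illustrative stand-in for a fused residual-add + RMSNorm."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            x = x + residual  # the add that the fused kernel absorbs
        normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        normed = normed * self.weight
        # Hand the pre-norm sum back as the next layer's residual.
        return normed if residual is None else (normed, x)
```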
|
||||
class BaiChuanModel(nn.Module):
|
||||
|
||||
def __init__(self, config: BaiChuanConfig, position_embedding: str):
|
||||
def __init__(self,
|
||||
config: BaiChuanConfig,
|
||||
position_embedding: str,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -237,9 +263,9 @@ class BaiChuanModel(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
perform_initialization=False)
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
BaiChuanDecoderLayer(config, position_embedding)
|
||||
BaiChuanDecoderLayer(config, position_embedding, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -253,34 +279,36 @@ class BaiChuanModel(nn.Module):
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
if cache_events is None:
|
||||
cache_event = None
|
||||
else:
|
||||
cache_event = cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
residual,
|
||||
)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BaiChuanBaseForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config, position_embedding: str):
|
||||
def __init__(self,
|
||||
config,
|
||||
position_embedding: str,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = BaiChuanModel(config, position_embedding)
|
||||
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.linear_method = linear_method
|
||||
self.model = BaiChuanModel(config, position_embedding, linear_method)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
@ -297,78 +325,46 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = []
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
|
||||
if "W_pack" in name:
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // total_num_heads
|
||||
num_heads = total_num_heads // tp_world_size
|
||||
head_start = tp_rank * num_heads
|
||||
head_end = (tp_rank + 1) * num_heads
|
||||
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size, hidden_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
|
||||
is_gate_up_weight = False
|
||||
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
|
||||
(tp_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
if is_gate_up_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tp_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(
|
||||
param,
|
||||
loaded_weight,
|
||||
name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tp_rank,
|
||||
)
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config, "ALIBI")
|
||||
def __init__(self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
super().__init__(config, "ALIBI", linear_method)
|
||||
|
||||
|
||||
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config, "ROPE")
|
||||
def __init__(self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
super().__init__(config, "ROPE", linear_method)
|
||||
|
@ -30,13 +30,17 @@ from transformers import BloomConfig
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -69,7 +73,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
|
||||
class BloomAttention(nn.Module):
|
||||
|
||||
def __init__(self, config: BloomConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.total_num_heads = config.n_head
|
||||
@ -80,19 +88,18 @@ class BloomAttention(nn.Module):
|
||||
assert self.total_num_heads % tp_world_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_world_size
|
||||
|
||||
self.query_key_value = ColumnParallelLinear(
|
||||
self.query_key_value = QKVParallelLinear(
|
||||
self.hidden_size,
|
||||
3 * self.hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
# Create the alibi slopes and slice them.
|
||||
@ -126,38 +133,49 @@ class BloomAttention(nn.Module):
|
||||
|
||||
class BloomMLP(nn.Module):
|
||||
|
||||
def __init__(self, config: BloomConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
|
||||
4 * hidden_size,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.act = get_act_fn("gelu")
|
||||
self.dense_4h_to_h = RowParallelLinear(4 * hidden_size,
|
||||
hidden_size,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.dense_h_to_4h = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
4 * hidden_size,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
quant_config = getattr(linear_method, "quant_config", None)
|
||||
self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
|
||||
self.dense_4h_to_h = RowParallelLinear(
|
||||
4 * hidden_size,
|
||||
hidden_size,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x, _ = self.dense_h_to_4h(x)
|
||||
x = self.act(x)
|
||||
x = self.gelu_impl(x)
|
||||
x, _ = self.dense_4h_to_h(x)
|
||||
return x
|
||||
|
||||
|
||||
class BloomBlock(nn.Module):
|
||||
|
||||
def __init__(self, config: BloomConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
|
||||
self.input_layernorm = nn.LayerNorm(hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
self.self_attention = BloomAttention(config)
|
||||
self.self_attention = BloomAttention(config, linear_method)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = BloomMLP(config)
|
||||
self.mlp = BloomMLP(config, linear_method)
|
||||
self.apply_residual_connection_post_layernorm = (
|
||||
config.apply_residual_connection_post_layernorm)
|
||||
|
||||
@ -202,19 +220,27 @@ class BloomBlock(nn.Module):
|
||||
|
||||
class BloomModel(nn.Module):
|
||||
|
||||
def __init__(self, config: BloomConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
# Embedding + LN Embedding
|
||||
self.word_embeddings = VocabParallelEmbedding(
|
||||
config.vocab_size, self.embed_dim, perform_initialization=False)
|
||||
config.vocab_size,
|
||||
self.embed_dim,
|
||||
)
|
||||
self.word_embeddings_layernorm = nn.LayerNorm(
|
||||
self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
# Transformer blocks
|
||||
self.h = nn.ModuleList(
|
||||
[BloomBlock(config) for _ in range(config.num_hidden_layers)])
|
||||
self.h = nn.ModuleList([
|
||||
BloomBlock(config, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
# Final Layer Norm
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
@ -248,12 +274,15 @@ class BloomModel(nn.Module):
|
||||
|
||||
class BloomForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config: BloomConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = BloomModel(config)
|
||||
# TODO(zhuohan): create a new weight after implementing pipeline
|
||||
# parallelism
|
||||
self.linear_method = linear_method
|
||||
self.transformer = BloomModel(config, linear_method)
|
||||
self.lm_head_weight = self.transformer.word_embeddings.weight
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
@ -271,54 +300,36 @@ class BloomForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"word_embeddings.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"
|
||||
]
|
||||
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if name == "lm_head.weight":
|
||||
# Since hidden_states are parallelized, we need to
|
||||
# load lm_head.weight in parallel.
|
||||
self._column_parallel_weights.append(name)
|
||||
# If lm_head is provided, use it instead.
|
||||
param = self.lm_head_weight
|
||||
else:
|
||||
if not name.startswith("transformer."):
|
||||
name = "transformer." + name
|
||||
param = state_dict[name]
|
||||
continue
|
||||
if not name.startswith("transformer."):
|
||||
name = "transformer." + name
|
||||
param = params_dict[name]
|
||||
|
||||
if "query_key_value" in name:
|
||||
# NOTE(woosuk): BLOOM's fused QKV has the shape of
|
||||
# [num_heads * 3 * head_size, hidden_size], while the
|
||||
# required shape is [3 * num_heads * head_size, hidden_size].
|
||||
# NOTE: BLOOM's fused QKV's output_dim has the shape of
|
||||
# (num_heads * 3 * head_size), while the
|
||||
# required shape is (3 * num_heads * head_size).
|
||||
# Thus, we need weight conversion.
|
||||
shard_size = param.shape[0]
|
||||
start = shard_size * tp_rank
|
||||
end = shard_size * (tp_rank + 1)
|
||||
loaded_weight = loaded_weight[start:end]
|
||||
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // num_heads
|
||||
if "query_key_value.weight" in name:
|
||||
loaded_weight = loaded_weight.view(-1, 3, head_size,
|
||||
hidden_size)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
elif "query_key_value.bias" in name:
|
||||
loaded_weight = loaded_weight.view(-1, 3, head_size)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
else:
|
||||
raise ValueError(f"Unexpected weight name: {name}")
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights, tp_rank)
|
||||
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
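The `output_dim`-based reshape above is a head-wise permutation: BLOOM checkpoints store the fused QKV grouped per head (`q_h0, k_h0, v_h0, q_h1, ...` along the output axis), while the fused layer wants all query heads first, then all keys, then all values. A small numeric check of the same view/transpose/reshape on a weight matrix (`output_dim = 0`):

```python
import torch

num_heads, head_size, hidden = 2, 3, 4
w = torch.arange(num_heads * 3 * head_size * hidden).reshape(-1, hidden)

shape = w.shape
regrouped = (w.view((num_heads, 3, -1) + shape[1:])
              .transpose(0, 1)
              .reshape(shape))

# Head 0's query rows stay first; head 0's key rows move after all query rows.
assert torch.equal(regrouped[:head_size], w[:head_size])
assert torch.equal(
    regrouped[num_heads * head_size:num_heads * head_size + head_size],
    w[head_size:2 * head_size])
```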
vllm/model_executor/models/chatglm.py (new file, 376 lines)
@@ -0,0 +1,376 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://github.com/THUDM/ChatGLM2-6B
|
||||
"""Inference-only ChatGLM model compatible with THUDM weights.
|
||||
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
class GLMAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.multi_query_attention = config.multi_query_attention
|
||||
self.total_num_kv_heads = (config.multi_query_group_num
|
||||
if config.multi_query_attention else
|
||||
config.num_attention_heads)
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = config.hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.query_key_value = QKVParallelLinear(
|
||||
self.hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=config.add_bias_linear or config.add_qkv_bias,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
config.hidden_size,
|
||||
bias=config.add_bias_linear,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
rotary_dim=self.head_dim // 2,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
is_neox_style=False,
|
||||
)
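The KV-head bookkeeping in this constructor covers both regimes: with at least as many KV heads as tensor-parallel ranks the heads are partitioned, otherwise each rank keeps a replica, hence `num_kv_heads = max(1, total_num_kv_heads // tp_size)`. A tiny helper reproducing just that arithmetic (illustrative, no distributed setup):

```python
def kv_heads_per_rank(total_num_kv_heads: int, tp_size: int) -> int:
    # Mirrors the assertion structure used in GLMAttention / FalconAttention.
    if total_num_kv_heads >= tp_size:
        assert total_num_kv_heads % tp_size == 0  # partition KV heads
    else:
        assert tp_size % total_num_kv_heads == 0  # replicate KV heads
    return max(1, total_num_kv_heads // tp_size)


print(kv_heads_per_rank(2, 2))   # 1: one KV group per rank
print(kv_heads_per_rank(2, 8))   # 1: each rank holds a replica of a KV head
print(kv_heads_per_rank(32, 4))  # 8: classic multi-head partitioning
```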
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.query_key_value(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
key_cache, value_cache = kv_cache
|
||||
|
||||
context_layer = self.attn(
|
||||
position_ids,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
|
||||
attn_output, _ = self.dense(context_layer)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class GLMMLP(nn.Module):
|
||||
"""MLP.
|
||||
|
||||
MLP will take the input with h hidden state, project it to 4*h
|
||||
hidden dimension, perform nonlinear transformation, and project the
|
||||
state back into h hidden dimension.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.add_bias = config.add_bias_linear
|
||||
|
||||
# Project to 4h.
|
||||
self.dense_h_to_4h = MergedColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
[config.ffn_hidden_size] * 2,
|
||||
bias=config.add_bias_linear,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
self.activation_func = SiluAndMul()
|
||||
|
||||
# Project back to h.
|
||||
self.dense_4h_to_h = RowParallelLinear(
|
||||
config.ffn_hidden_size,
|
||||
config.hidden_size,
|
||||
bias=config.add_bias_linear,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# [s, b, 4hp]
|
||||
intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
|
||||
intermediate_parallel = self.activation_func(intermediate_parallel)
|
||||
# [s, b, h]
|
||||
output, _ = self.dense_4h_to_h(intermediate_parallel)
|
||||
return output
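Because `SiluAndMul` consumes the fused gate/up activation, `dense_h_to_4h` produces `2 * ffn_hidden_size` features and the activation halves them again: `out = silu(x[..., :d]) * x[..., d:]`. A one-liner showing the shape contract (assuming the standard definition of `SiluAndMul`):

```python
import torch
import torch.nn.functional as F

d = 5
x = torch.randn(2, 2 * d)            # fused [gate | up] activation
out = F.silu(x[..., :d]) * x[..., d:]
print(out.shape)                     # torch.Size([2, 5])
```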
|
||||
class GLMBlock(nn.Module):
|
||||
"""A single transformer layer.
|
||||
|
||||
Transformer layer takes input with size [s, b, h] and returns an
|
||||
output of the same size.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.apply_residual_connection_post_layernorm = (
|
||||
config.apply_residual_connection_post_layernorm)
|
||||
|
||||
self.fp32_residual_connection = config.fp32_residual_connection
|
||||
|
||||
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
||||
# Layernorm on the input data.
|
||||
self.input_layernorm = layer_norm_func(config.hidden_size,
|
||||
eps=config.layernorm_epsilon)
|
||||
|
||||
# Self attention.
|
||||
self.self_attention = GLMAttention(config, linear_method)
|
||||
self.hidden_dropout = config.hidden_dropout
|
||||
|
||||
# Layernorm on the attention output
|
||||
self.post_attention_layernorm = layer_norm_func(
|
||||
config.hidden_size, eps=config.layernorm_epsilon)
|
||||
|
||||
# MLP
|
||||
self.mlp = GLMMLP(config, linear_method)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
# hidden_states: [num_tokens, h]
|
||||
# Layer norm at the beginning of the transformer layer.
|
||||
layernorm_output = self.input_layernorm(hidden_states)
|
||||
# Self attention.
|
||||
attention_output = self.self_attention(
|
||||
hidden_states=layernorm_output,
|
||||
position_ids=position_ids,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
|
||||
# Residual connection.
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = hidden_states
|
||||
|
||||
layernorm_input = residual + attention_output
|
||||
|
||||
# Layer norm post the self attention.
|
||||
layernorm_output = self.post_attention_layernorm(layernorm_input)
|
||||
|
||||
# Second residual connection.
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = layernorm_input
|
||||
|
||||
output = self.mlp(layernorm_output) + residual
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class GLMTransformer(nn.Module):
|
||||
"""Transformer class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.post_layer_norm = config.post_layer_norm
|
||||
|
||||
# Number of layers.
|
||||
self.num_layers = config.num_layers
|
||||
|
||||
# Transformer layers.
|
||||
self.layers = nn.ModuleList(
|
||||
[GLMBlock(config, linear_method) for i in range(self.num_layers)])
|
||||
|
||||
if self.post_layer_norm:
|
||||
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
||||
# Final layer norm before output.
|
||||
self.final_layernorm = layer_norm_func(
|
||||
config.hidden_size, eps=config.layernorm_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
for i in range(self.num_layers):
|
||||
if cache_events is None:
|
||||
cache_event = None
|
||||
else:
|
||||
cache_event = cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
hidden_states=hidden_states,
|
||||
position_ids=position_ids,
|
||||
kv_cache=kv_caches[i],
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
# Final layer norm.
|
||||
if self.post_layer_norm:
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ChatGLMModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
|
||||
config.hidden_size)
|
||||
|
||||
self.num_layers = config.num_layers
|
||||
self.multi_query_group_num = config.multi_query_group_num
|
||||
self.kv_channels = config.kv_channels
|
||||
self.encoder = GLMTransformer(config, linear_method)
|
||||
|
||||
self.output_layer = ParallelLMHead(config.padded_vocab_size,
|
||||
config.hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
):
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
|
||||
# Run encoder.
|
||||
hidden_states = self.encoder(
|
||||
hidden_states=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
kv_caches=kv_caches,
|
||||
input_metadata=input_metadata,
|
||||
cache_events=cache_events,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ChatGLMForCausalLM(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ChatGLMConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config: ChatGLMConfig = config
|
||||
self.linear_method = linear_method
|
||||
self.transformer = ChatGLMModel(config, linear_method)
|
||||
self.lm_head_weight = self.transformer.output_layer.weight
|
||||
self.sampler = Sampler(config.padded_vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_pos_emb.inv_freq" in name:
|
||||
continue
|
||||
if "word_embeddings" in name:
|
||||
name = name.replace(".word_embeddings", "")
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
@ -27,18 +27,23 @@ from torch.nn import LayerNorm
|
||||
from transformers import FalconConfig as HF_FalconConfig
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import (PagedAttention,
|
||||
PagedAttentionWithALiBi,
|
||||
PagedAttentionWithRoPE)
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
|
||||
hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
|
||||
reduce_from_tensor_model_parallel_region)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs import RWConfig
|
||||
|
||||
@ -46,19 +51,6 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
FalconConfig = Union[HF_FalconConfig, RWConfig]
|
||||
|
||||
|
||||
# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during
|
||||
# training, this means that there's one additional quantization to bfloat16
|
||||
# between the operations. In order not to degrade the quality of our HF-port,
|
||||
# we keep these characteristics in the final model.
|
||||
class FalconLinear(nn.Linear):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = x @ self.weight.T
|
||||
if self.bias is None:
|
||||
return hidden_states
|
||||
return hidden_states + self.bias
|
||||
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
||||
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
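The ALiBi helper is only partially visible in this hunk. For context, the standard slope schedule (as in the BLOOM reference implementation these files follow) is sketched below; treat it as background, not as the exact contents of the file:

```python
import math

import torch


def alibi_slopes_sketch(total_num_heads: int) -> torch.Tensor:
    """Standard BLOOM-style ALiBi slopes; shown for reference only."""
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
                        dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)
    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32)
        num_remaining = min(closest_power_of_2,
                            total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining, 2,
                                    dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


print(alibi_slopes_sketch(8))  # 8 slopes: 1/2, 1/4, ..., 1/256
```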
@ -84,7 +76,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
|
||||
class FalconAttention(nn.Module):
|
||||
|
||||
def __init__(self, config: FalconConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -101,44 +97,29 @@ class FalconAttention(nn.Module):
|
||||
|
||||
if self.new_decoder_architecture:
|
||||
self.total_num_kv_heads = config.num_kv_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_kv_heads = self.total_num_kv_heads // tp_size
|
||||
self.query_key_value = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||
self.head_dim,
|
||||
bias=config.bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True,
|
||||
)
|
||||
elif self.multi_query:
|
||||
self.total_num_kv_heads = 1
|
||||
self.num_kv_heads = 1
|
||||
self.query = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
self.total_num_heads * self.head_dim,
|
||||
bias=config.bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True,
|
||||
)
|
||||
self.key_value = FalconLinear(self.hidden_size,
|
||||
2 * self.head_dim,
|
||||
bias=config.bias)
|
||||
else:
|
||||
self.total_num_kv_heads = self.total_num_heads
|
||||
self.num_kv_heads = self.num_heads
|
||||
self.query_key_value = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||
self.head_dim,
|
||||
bias=config.bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True,
|
||||
)
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
|
||||
self.query_key_value = QKVParallelLinear(
|
||||
self.hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=config.bias,
|
||||
skip_bias_add=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
|
||||
@ -150,9 +131,8 @@ class FalconAttention(nn.Module):
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=config.bias,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True,
|
||||
linear_method=linear_method,
|
||||
reduce_results=self.reduce_row_parallel_results)
|
||||
|
||||
self.use_rotary = config.rotary
|
||||
@ -161,12 +141,17 @@ class FalconAttention(nn.Module):
|
||||
"Rotary and alibi are mutually exclusive.")
|
||||
|
||||
if self.use_rotary:
|
||||
# TODO(zhuohan): Pass in correct `max_position``
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_dim,
|
||||
self.inv_norm_factor,
|
||||
rotary_dim=self.head_dim,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
max_position_embeddings = getattr(config,
|
||||
"max_position_embeddings", 8192)
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.inv_norm_factor,
|
||||
base=rope_theta,
|
||||
max_position=max_position_embeddings,
|
||||
rotary_dim=self.head_dim,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
elif self.use_alibi:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
head_start = tp_rank * self.num_heads
|
||||
@ -193,18 +178,10 @@ class FalconAttention(nn.Module):
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
if not self.new_decoder_architecture and self.multi_query:
|
||||
q, bias = self.query(hidden_states)
|
||||
if bias is not None:
|
||||
q += bias
|
||||
kv = self.key_value(hidden_states)
|
||||
k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
|
||||
else:
|
||||
qkv, bias = self.query_key_value(hidden_states)
|
||||
if bias is not None:
|
||||
qkv += bias
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
|
||||
dim=-1)
|
||||
qkv, bias = self.query_key_value(hidden_states)
|
||||
if bias is not None:
|
||||
qkv += bias
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
k_cache, v_cache = kv_cache
|
||||
if self.use_rotary:
|
||||
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||
@ -218,27 +195,30 @@ class FalconAttention(nn.Module):
|
||||
|
||||
class FalconMLP(nn.Module):
|
||||
|
||||
def __init__(self, config: FalconConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
|
||||
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
|
||||
4 * hidden_size,
|
||||
bias=config.bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True)
|
||||
self.act = nn.GELU()
|
||||
skip_bias_add=True,
|
||||
linear_method=linear_method)
|
||||
quant_config = getattr(linear_method, "quant_config", None)
|
||||
self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
|
||||
self.reduce_row_parallel_results = not (config.new_decoder_architecture
|
||||
or config.parallel_attn)
|
||||
self.dense_4h_to_h = RowParallelLinear(
|
||||
4 * hidden_size,
|
||||
hidden_size,
|
||||
bias=config.bias,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True,
|
||||
reduce_results=self.reduce_row_parallel_results)
|
||||
reduce_results=self.reduce_row_parallel_results,
|
||||
linear_method=linear_method)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
|
||||
@ -252,12 +232,16 @@ class FalconMLP(nn.Module):
|
||||
|
||||
class FalconDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: FalconConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.self_attention = FalconAttention(config)
|
||||
self.mlp = FalconMLP(config)
|
||||
self.self_attention = FalconAttention(config, linear_method)
|
||||
self.mlp = FalconMLP(config, linear_method)
|
||||
self.config = config
|
||||
|
||||
if config.new_decoder_architecture:
|
||||
@ -320,7 +304,7 @@ class FalconDecoderLayer(nn.Module):
|
||||
# only one all-reduce operator to reduce the results from
|
||||
# both MLP and Attention layers.
|
||||
mlp_output += attention_output
|
||||
mlp_output = reduce_from_tensor_model_parallel_region(mlp_output)
|
||||
mlp_output = tensor_model_parallel_all_reduce(mlp_output)
|
||||
if attention_bias is not None:
|
||||
mlp_output += attention_bias
|
||||
if mlp_bias is not None:
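This hunk is truncated by the diff view, but the comment spells out the optimization: all-reduce is linear, so reducing `mlp_output + attention_output` once is equivalent to reducing each separately and adding, saving one collective per layer. A single-process illustration that models the reduction as a sum over per-rank partial results:

```python
import torch

ranks = 4
mlp_parts = [torch.randn(3) for _ in range(ranks)]
attn_parts = [torch.randn(3) for _ in range(ranks)]

two_reduces = sum(mlp_parts) + sum(attn_parts)                  # reduce each, then add
one_reduce = sum(m + a for m, a in zip(mlp_parts, attn_parts))  # add, then reduce once
assert torch.allclose(two_reduces, one_reduce)
```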
@ -333,7 +317,11 @@ class FalconDecoderLayer(nn.Module):
|
||||
|
||||
class FalconModel(nn.Module):
|
||||
|
||||
def __init__(self, config: FalconConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -342,11 +330,14 @@ class FalconModel(nn.Module):
|
||||
|
||||
# Embedding + LN Embedding
|
||||
self.word_embeddings = VocabParallelEmbedding(
|
||||
config.vocab_size, self.embed_dim, perform_initialization=False)
|
||||
config.vocab_size,
|
||||
self.embed_dim,
|
||||
)
|
||||
|
||||
# Transformer blocks
|
||||
self.h = nn.ModuleList([
|
||||
FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)
|
||||
FalconDecoderLayer(config, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
# Final Layer Norm
|
||||
@ -380,15 +371,19 @@ class FalconModel(nn.Module):
|
||||
|
||||
class FalconForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config: FalconConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = FalconModel(config)
|
||||
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.linear_method = linear_method
|
||||
self.transformer = FalconModel(config, linear_method)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
@ -411,88 +406,44 @@ class FalconForCausalLM(nn.Module):
|
||||
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"word_embeddings.weight", "lm_head.weight", "dense_h_to_4h.weight",
|
||||
"dense_h_to_4h.bias"
|
||||
]
|
||||
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
tp_size = (get_tensor_model_parallel_world_size())
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
hidden_size = self.config.hidden_size
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
num_heads = total_num_heads // tp_size
|
||||
head_size = hidden_size // total_num_heads
|
||||
head_start = tp_rank * num_heads
|
||||
head_end = (tp_rank + 1) * num_heads
|
||||
if self.config.new_decoder_architecture:
|
||||
total_num_kv_heads = self.config.num_kv_heads
|
||||
num_kv_heads = total_num_kv_heads // tp_size
|
||||
separated_q_kv = False
|
||||
kv_head_start = tp_rank * num_kv_heads
|
||||
kv_head_end = (tp_rank + 1) * num_kv_heads
|
||||
elif self.config.multi_query:
|
||||
total_num_kv_heads = 1
|
||||
num_kv_heads = 1
|
||||
separated_q_kv = True
|
||||
kv_head_start = 0
|
||||
kv_head_end = 1
|
||||
else:
|
||||
total_num_kv_heads = total_num_heads
|
||||
num_kv_heads = total_num_kv_heads // tp_size
|
||||
separated_q_kv = False
|
||||
kv_head_start = tp_rank * num_kv_heads
|
||||
kv_head_end = (tp_rank + 1) * num_kv_heads
|
||||
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
|
||||
state_dict = self.state_dict()
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
param = params_dict[name]
|
||||
if "query_key_value" in name:
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
loaded_weight_size = loaded_weight.size()
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
loaded_weight_shape = loaded_weight.shape
|
||||
loaded_weight = loaded_weight.view(
|
||||
total_num_kv_heads, num_query_heads_per_kv_head + 2,
|
||||
head_size, *loaded_weight_size[1:])
|
||||
loaded_weight_shape[:output_dim] +
|
||||
(total_num_kv_heads, num_query_heads_per_kv_head + 2, -1) +
|
||||
loaded_weight_shape[output_dim + 1:])
|
||||
wq = loaded_weight.narrow(
|
||||
output_dim + 1, 0, num_query_heads_per_kv_head).reshape(
|
||||
*loaded_weight_shape[:output_dim], -1,
|
||||
*loaded_weight_shape[output_dim + 1:])
|
||||
wk = loaded_weight.narrow(
|
||||
output_dim + 1, num_query_heads_per_kv_head,
|
||||
1).reshape(*loaded_weight_shape[:output_dim], -1,
|
||||
*loaded_weight_shape[output_dim + 1:])
|
||||
wv = loaded_weight.narrow(
|
||||
output_dim + 1, num_query_heads_per_kv_head + 1,
|
||||
1).reshape(*loaded_weight_shape[:output_dim], -1,
|
||||
*loaded_weight_shape[output_dim + 1:])
|
||||
loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
|
||||
|
||||
wq = loaded_weight[:, :-2].reshape(-1, *loaded_weight_size[1:])
|
||||
wk = loaded_weight[:, [-2]].reshape(-1,
|
||||
*loaded_weight_size[1:])
|
||||
wv = loaded_weight[:, [-1]].reshape(-1,
|
||||
*loaded_weight_size[1:])
|
||||
|
||||
wq = wq[head_size * head_start:head_size * head_end]
|
||||
wk = wk[head_size * kv_head_start:head_size * kv_head_end]
|
||||
wv = wv[head_size * kv_head_start:head_size * kv_head_end]
|
||||
|
||||
if separated_q_kv:
|
||||
loaded_weight_q = wq
|
||||
loaded_weight_kv = torch.cat([wk, wv], dim=0)
|
||||
q_weight_name = name.replace("query_key_value", "query")
|
||||
kv_weight_name = name.replace("query_key_value",
|
||||
"key_value")
|
||||
load_tensor_parallel_weights(state_dict[q_weight_name],
|
||||
loaded_weight_q,
|
||||
q_weight_name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tp_rank)
|
||||
load_tensor_parallel_weights(state_dict[kv_weight_name],
|
||||
loaded_weight_kv,
|
||||
kv_weight_name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tp_rank)
|
||||
continue
|
||||
else:
|
||||
loaded_weight = torch.cat([wq, wk, wv], dim=0)
|
||||
|
||||
param = state_dict[name]
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights, tp_rank)
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
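The new Falcon loader un-interleaves the checkpoint's grouped QKV, where each KV group stores `num_query_heads_per_kv_head` query slices followed by one key and one value slice, into the `[all q | all k | all v]` layout expected by the fused layer, using `narrow` along the group axis and a final `cat`. A small numeric check of that regrouping (sizes chosen arbitrarily):

```python
import torch

num_kv_heads, q_per_kv, head_size, hidden = 2, 3, 2, 4
rows = num_kv_heads * (q_per_kv + 2) * head_size
w = torch.arange(rows * hidden).reshape(rows, hidden)   # checkpoint layout

grouped = w.view(num_kv_heads, q_per_kv + 2, head_size, hidden)
wq = grouped[:, :q_per_kv].reshape(-1, hidden)          # all query heads
wk = grouped[:, [q_per_kv]].reshape(-1, hidden)         # one key head per group
wv = grouped[:, [q_per_kv + 1]].reshape(-1, hidden)     # one value head per group
regrouped = torch.cat([wq, wk, wv], dim=0)

assert regrouped.shape == w.shape
# Group 0's key rows sit right after its query rows in the original layout.
assert torch.equal(wk[:head_size],
                   w[q_per_kv * head_size:(q_per_kv + 1) * head_size])
```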