mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
Compare commits
64 Commits
Author | SHA1 | Date | |
---|---|---|---|
e2fb71ec9f | |||
f936657eb6 | |||
6f88f762bf | |||
202351d5bf | |||
2e8e49fce3 | |||
a8e98aee0c | |||
bb1ba58f06 | |||
7bedab5748 | |||
20f7cc4cde | |||
649aa730c5 | |||
a19bc5c628 | |||
28e616c4e3 | |||
30e775281d | |||
21877b0d75 | |||
cf5cb1e33e | |||
03ffd0a022 | |||
a425bd9a9a | |||
bbbf86565f | |||
9f6be8692e | |||
f187877945 | |||
947b794146 | |||
8d926e91f1 | |||
4ee52bb169 | |||
7d7e3b78a3 | |||
f98b745a81 | |||
2d1e86f1b1 | |||
1ac4ccf73c | |||
2ac4d5e2bf | |||
3302f0aef3 | |||
6f2dd6c37e | |||
bc0644574c | |||
400b8289f7 | |||
c1026311b5 | |||
2b1c116b5a | |||
cc796b1358 | |||
f029ef94d7 | |||
95592fa00a | |||
fbe66e1d0b | |||
90979c38f8 | |||
e21d7687a9 | |||
ff36139ffc | |||
e3e79e9e8a | |||
b9fe4616f9 | |||
64ca424e75 | |||
b5f93d0631 | |||
a58936966f | |||
dd54a4b026 | |||
eda1a7cad3 | |||
f04908cae7 | |||
ab019eea75 | |||
9841d48a10 | |||
3272d7a0b7 | |||
0bb1e885a0 | |||
d6545ad22e | |||
90eb3f43ca | |||
e67b4f2c2a | |||
d6770d1f23 | |||
b9cecc2635 | |||
898285c9bf | |||
a62de9ecfd | |||
4042d192f5 | |||
1117aa1411 | |||
080438477f | |||
4b5bcf8906 |
4
.gitignore
vendored
4
.gitignore
vendored
@ -173,3 +173,7 @@ cython_debug/
|
||||
|
||||
# Sphinx documentation
|
||||
_build/
|
||||
|
||||
# vim swap files
|
||||
*.swo
|
||||
*.swp
|
||||
|
32
README.md
32
README.md
@ -10,13 +10,24 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
</h3>
|
||||
|
||||
<p align="center">
|
||||
| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://github.com/vllm-project/vllm/discussions"><b>Discussions</b></a> |
|
||||
| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |
|
||||
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
**The First vLLM Bay Area Meetup (Oct 5th 6pm-8pm PT)**
|
||||
|
||||
We are excited to invite you to the first vLLM meetup!
|
||||
The vLLM team will share recent updates and roadmap.
|
||||
We will also have vLLM users and contributors coming up to the stage to share their experiences.
|
||||
Please register [here](https://lu.ma/first-vllm-meetup) and join us!
|
||||
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
|
||||
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
|
||||
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
|
||||
- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
|
||||
- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
|
||||
@ -35,13 +46,13 @@ vLLM is fast with:
|
||||
|
||||
vLLM is flexible and easy to use with:
|
||||
|
||||
- Seamless integration with popular HuggingFace models
|
||||
- Seamless integration with popular Hugging Face models
|
||||
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
||||
- Tensor parallelism support for distributed inference
|
||||
- Streaming outputs
|
||||
- OpenAI-compatible API server
|
||||
|
||||
vLLM seamlessly supports many Huggingface models, including the following architectures:
|
||||
vLLM seamlessly supports many Hugging Face models, including the following architectures:
|
||||
|
||||
- Aquila (`BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
|
||||
- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
|
||||
@ -53,6 +64,7 @@ vLLM seamlessly supports many Huggingface models, including the following archit
|
||||
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
|
||||
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
|
||||
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
||||
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
|
||||
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
||||
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
||||
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
||||
@ -72,7 +84,7 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started
|
||||
|
||||
## Performance
|
||||
|
||||
vLLM outperforms HuggingFace Transformers (HF) by up to 24x and Text Generation Inference (TGI) by up to 3.5x, in terms of throughput.
|
||||
vLLM outperforms Hugging Face Transformers (HF) by up to 24x and Text Generation Inference (TGI) by up to 3.5x, in terms of throughput.
|
||||
For details, check out our [blog post](https://vllm.ai).
|
||||
|
||||
<p align="center">
|
||||
@ -104,3 +116,15 @@ For details, check out our [blog post](https://vllm.ai).
|
||||
|
||||
We welcome and value any contributions and collaborations.
|
||||
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
|
||||
|
||||
## Citation
|
||||
|
||||
If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
|
||||
```bibtex
|
||||
@inproceedings{kwon2023efficient,
|
||||
title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
|
||||
author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
|
||||
booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
|
@ -18,6 +18,7 @@ def main(args: argparse.Namespace):
|
||||
llm = LLM(
|
||||
model=args.model,
|
||||
tokenizer=args.tokenizer,
|
||||
quantization=args.quantization,
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
max_num_seqs=args.batch_size,
|
||||
max_num_batched_tokens=args.batch_size * args.input_len,
|
||||
@ -63,19 +64,28 @@ def main(args: argparse.Namespace):
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Benchmark the latency of processing a single batch of '
|
||||
'requests till completion.')
|
||||
'requests till completion.')
|
||||
parser.add_argument('--model', type=str, default='facebook/opt-125m')
|
||||
parser.add_argument('--tokenizer', type=str, default=None)
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
choices=['awq', None],
|
||||
default=None)
|
||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
||||
parser.add_argument('--input-len', type=int, default=32)
|
||||
parser.add_argument('--output-len', type=int, default=128)
|
||||
parser.add_argument('--batch-size', type=int, default=8)
|
||||
parser.add_argument('--n', type=int, default=1,
|
||||
parser.add_argument('--n',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Number of generated sequences per prompt.')
|
||||
parser.add_argument('--use-beam-search', action='store_true')
|
||||
parser.add_argument('--num-iters', type=int, default=3,
|
||||
parser.add_argument('--num-iters',
|
||||
type=int,
|
||||
default=3,
|
||||
help='Number of iterations to run.')
|
||||
parser.add_argument('--trust-remote-code', action='store_true',
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
help='trust remote code from huggingface')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -3,7 +3,7 @@ import argparse
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
|
||||
@ -22,15 +22,10 @@ def sample_requests(
|
||||
with open(dataset_path) as f:
|
||||
dataset = json.load(f)
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [
|
||||
data for data in dataset
|
||||
if len(data["conversations"]) >= 2
|
||||
]
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [
|
||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||
for data in dataset
|
||||
]
|
||||
dataset = [(data["conversations"][0]["value"],
|
||||
data["conversations"][1]["value"]) for data in dataset]
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompts = [prompt for prompt, _ in dataset]
|
||||
@ -63,6 +58,7 @@ def run_vllm(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
model: str,
|
||||
tokenizer: str,
|
||||
quantization: Optional[str],
|
||||
tensor_parallel_size: int,
|
||||
seed: int,
|
||||
n: int,
|
||||
@ -72,6 +68,7 @@ def run_vllm(
|
||||
llm = LLM(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
quantization=quantization,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
seed=seed,
|
||||
trust_remote_code=trust_remote_code,
|
||||
@ -111,8 +108,8 @@ def run_hf(
|
||||
trust_remote_code: bool,
|
||||
) -> float:
|
||||
assert not use_beam_search
|
||||
llm = AutoModelForCausalLM.from_pretrained(model,
|
||||
torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
||||
if llm.config.model_type == "llama":
|
||||
# To enable padding in the HF backend.
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
@ -132,13 +129,14 @@ def run_hf(
|
||||
if len(batch) < max_batch_size and i != len(requests) - 1:
|
||||
# Check if we can add more requests to the batch.
|
||||
_, next_prompt_len, next_output_len = requests[i + 1]
|
||||
if (max(max_prompt_len, next_prompt_len) + max(
|
||||
max_output_len, next_output_len)) <= 2048:
|
||||
if (max(max_prompt_len, next_prompt_len) +
|
||||
max(max_output_len, next_output_len)) <= 2048:
|
||||
# We can add more requests to the batch.
|
||||
continue
|
||||
|
||||
# Generate the sequences.
|
||||
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
|
||||
input_ids = tokenizer(batch, return_tensors="pt",
|
||||
padding=True).input_ids
|
||||
llm_outputs = llm.generate(
|
||||
input_ids=input_ids.cuda(),
|
||||
do_sample=not use_beam_search,
|
||||
@ -165,44 +163,58 @@ def main(args: argparse.Namespace):
|
||||
random.seed(args.seed)
|
||||
|
||||
# Sample the requests.
|
||||
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
tokenizer = get_tokenizer(args.tokenizer,
|
||||
trust_remote_code=args.trust_remote_code)
|
||||
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||
|
||||
if args.backend == "vllm":
|
||||
elapsed_time = run_vllm(
|
||||
requests, args.model, args.tokenizer, args.tensor_parallel_size,
|
||||
args.seed, args.n, args.use_beam_search, args.trust_remote_code)
|
||||
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
|
||||
args.quantization, args.tensor_parallel_size,
|
||||
args.seed, args.n, args.use_beam_search,
|
||||
args.trust_remote_code)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(
|
||||
requests, args.model, tokenizer, args.n, args.use_beam_search,
|
||||
args.hf_max_batch_size, args.trust_remote_code)
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
args.use_beam_search, args.hf_max_batch_size,
|
||||
args.trust_remote_code)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
total_num_tokens = sum(
|
||||
prompt_len + output_len
|
||||
for _, prompt_len, output_len in requests
|
||||
)
|
||||
total_num_tokens = sum(prompt_len + output_len
|
||||
for _, prompt_len, output_len in requests)
|
||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
||||
parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
|
||||
parser.add_argument("--backend",
|
||||
type=str,
|
||||
choices=["vllm", "hf"],
|
||||
default="vllm")
|
||||
parser.add_argument("--dataset", type=str, required=True,
|
||||
parser.add_argument("--dataset",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the dataset.")
|
||||
parser.add_argument("--model", type=str, default="facebook/opt-125m")
|
||||
parser.add_argument("--tokenizer", type=str, default=None)
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
choices=['awq', None],
|
||||
default=None)
|
||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
||||
parser.add_argument("--n", type=int, default=1,
|
||||
parser.add_argument("--n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of generated sequences per prompt.")
|
||||
parser.add_argument("--use-beam-search", action="store_true")
|
||||
parser.add_argument("--num-prompts", type=int, default=1000,
|
||||
parser.add_argument("--num-prompts",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of prompts to process.")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--hf-max-batch-size", type=int, default=None,
|
||||
parser.add_argument("--hf-max-batch-size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum batch size for HF backend.")
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
@ -215,6 +227,8 @@ if __name__ == "__main__":
|
||||
elif args.backend == "hf":
|
||||
if args.hf_max_batch_size is None:
|
||||
raise ValueError("HF max batch size is required for HF backend.")
|
||||
if args.quantization is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
|
||||
|
@ -341,6 +341,9 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
} // namespace vllm
|
||||
|
||||
#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
|
||||
cudaFuncSetAttribute( \
|
||||
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
|
||||
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
|
||||
<<<grid, block, shared_mem_size, stream>>>( \
|
||||
out_ptr, \
|
||||
@ -401,6 +404,8 @@ void single_query_cached_kv_attention_launcher(
|
||||
int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
|
||||
int logits_size = padded_max_context_len * sizeof(float);
|
||||
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
||||
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
|
||||
// Keep that in sync with the logic here!
|
||||
int shared_mem_size = std::max(logits_size, outputs_size);
|
||||
|
||||
dim3 grid(num_heads, num_seqs);
|
||||
|
13
csrc/cuda_utils.cpp
Normal file
13
csrc/cuda_utils.cpp
Normal file
@ -0,0 +1,13 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
int get_device_attribute(
|
||||
int attribute,
|
||||
int device_id);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"get_device_attribute",
|
||||
&get_device_attribute,
|
||||
"Gets the specified device attribute.");
|
||||
}
|
||||
|
14
csrc/cuda_utils_kernels.cu
Normal file
14
csrc/cuda_utils_kernels.cu
Normal file
@ -0,0 +1,14 @@
|
||||
int get_device_attribute(
|
||||
int attribute,
|
||||
int device_id)
|
||||
{
|
||||
int device, value;
|
||||
if (device_id < 0) {
|
||||
cudaGetDevice(&device);
|
||||
}
|
||||
else {
|
||||
device = device_id;
|
||||
}
|
||||
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
|
||||
return value;
|
||||
}
|
15
csrc/quantization.cpp
Normal file
15
csrc/quantization.cpp
Normal file
@ -0,0 +1,15 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"awq_gemm",
|
||||
&awq_gemm,
|
||||
"Quantized GEMM for AWQ");
|
||||
}
|
87
csrc/quantization/awq/dequantize.cuh
Normal file
87
csrc/quantization/awq/dequantize.cuh
Normal file
@ -0,0 +1,87 @@
|
||||
/*
|
||||
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
@article{lin2023awq,
|
||||
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
||||
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
||||
journal={arXiv},
|
||||
year={2023}
|
||||
}
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace vllm {
|
||||
namespace awq {
|
||||
|
||||
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
uint4 result;
|
||||
|
||||
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
|
||||
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
|
||||
|
||||
// First, we extract the i4s and construct an intermediate fp16 number.
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
|
||||
static constexpr uint32_t TOP_MASK = 0x00f000f0;
|
||||
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||
|
||||
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
|
||||
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
|
||||
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
|
||||
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
|
||||
|
||||
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
|
||||
// immediately before required.
|
||||
const uint32_t top_i4s = i4s >> 8;
|
||||
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[0])
|
||||
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[1])
|
||||
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[2])
|
||||
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[3])
|
||||
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
|
||||
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
|
||||
// half2 ctor. In this case, I chose performance reliability over code readability.
|
||||
|
||||
// This is the half2 {1032, 1032} represented as an integer.
|
||||
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
|
||||
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
|
||||
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
|
||||
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
|
||||
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
|
||||
// This is the half2 {-72, -72} represented as an integer.
|
||||
// static constexpr uint32_t NEG_72 = 0xd480d480;
|
||||
// Haotian: Let's use {-64, -64}.
|
||||
static constexpr uint32_t NEG_64 = 0xd400d400;
|
||||
|
||||
// Finally, we construct the output numbers.
|
||||
// Convert elt_01
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
|
||||
// Convert elt_23
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||
// Convert elt_45
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
|
||||
// Convert elt_67
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace awq
|
||||
} // namespace vllm
|
491
csrc/quantization/awq/gemm_kernels.cu
Normal file
491
csrc/quantization/awq/gemm_kernels.cu
Normal file
@ -0,0 +1,491 @@
|
||||
/*
|
||||
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||
@article{lin2023awq,
|
||||
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
||||
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
||||
journal={arXiv},
|
||||
year={2023}
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "dequantize.cuh"
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
namespace vllm {
|
||||
namespace awq {
|
||||
|
||||
// Pack two half values.
|
||||
static inline __device__ __host__ unsigned
|
||||
__pack_half2(const half x, const half y) {
|
||||
unsigned v0 = *((unsigned short *)&x);
|
||||
unsigned v1 = *((unsigned short *)&y);
|
||||
return (v1 << 16) | v0;
|
||||
}
|
||||
|
||||
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
static constexpr uint32_t ZERO = 0x0;
|
||||
float C_warp[32];
|
||||
__shared__ half A_shared[16 * (32 + 8)];
|
||||
__shared__ half B_shared[32 * (128 + 8)];
|
||||
|
||||
__shared__ half scaling_factors_shared[128];
|
||||
__shared__ half zeros_shared[128];
|
||||
|
||||
int j_factors1 = ((OC + 128 - 1) / 128);
|
||||
int blockIdx_x = 0;
|
||||
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||
|
||||
half A_shared_warp[8];
|
||||
half B_shared_warp[32];
|
||||
for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||
static constexpr int row_stride = 2 * 32 * 8 / 128;
|
||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||
|
||||
half* A_ptr = A
|
||||
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
||||
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||
|
||||
int* B_ptr = B
|
||||
+ ((int)threadIdx.y) * (OC / 8) * 2
|
||||
+ (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
|
||||
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
|
||||
+ (((int)threadIdx.x) % (128 / 8)) * 1;
|
||||
// Why * 1 in the above line?
|
||||
|
||||
half* A_shared_ptr = A_shared
|
||||
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
||||
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
||||
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||
|
||||
half* B_shared_ptr = B_shared
|
||||
+ ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
|
||||
+ (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
|
||||
+ (((int)threadIdx.x) % (128 / 8)) * 8;
|
||||
|
||||
int* zeros_ptr = zeros
|
||||
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
|
||||
+ ((int)threadIdx.x) % (128 / 8);
|
||||
|
||||
half* scaling_factors_ptr = scaling_factors
|
||||
+ (((int)blockIdx_y) % j_factors1) * (128)
|
||||
+ (((int)threadIdx.x) % (128 / 8)) * 8;
|
||||
|
||||
half* C_ptr = C
|
||||
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
|
||||
+ (((int)blockIdx_y) % j_factors1) * 128
|
||||
+ ((int)threadIdx.y) * 64
|
||||
+ (((int)threadIdx.x) % 4) * 2;
|
||||
|
||||
// preload s.f. and zeros
|
||||
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
|
||||
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
|
||||
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||
__syncthreads();
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
if (ld_A_flag)
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||
}
|
||||
else
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||
/*
|
||||
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
||||
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||
}
|
||||
*/
|
||||
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||
|
||||
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
|
||||
|
||||
// B: 32 x 136 (128+8) float16
|
||||
// each warp: 32 x 4
|
||||
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
||||
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
||||
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
|
||||
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
// - zero and * scale
|
||||
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
/*
|
||||
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
||||
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||
}
|
||||
*/
|
||||
|
||||
// write back
|
||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
|
||||
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
|
||||
for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
}
|
||||
for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Shang: Hoist loop invariance.
|
||||
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
|
||||
for (int local_id = 0; local_id < 8; ++local_id) {
|
||||
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||
if (row_offset < M)
|
||||
{
|
||||
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
static constexpr uint32_t ZERO = 0x0;
|
||||
float C_warp[32];
|
||||
__shared__ half A_shared[16 * (32 + 8)];
|
||||
__shared__ half B_shared[32 * (64 + 8)];
|
||||
|
||||
__shared__ half scaling_factors_shared[64];
|
||||
__shared__ half zeros_shared[64];
|
||||
|
||||
int j_factors1 = ((OC + 64 - 1) / 64);
|
||||
|
||||
int blockIdx_x = 0;
|
||||
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||
|
||||
half A_shared_warp[8];
|
||||
half B_shared_warp[16];
|
||||
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||
static constexpr int row_stride = 2 * 32 * 8 / 64;
|
||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||
|
||||
half* A_ptr = A
|
||||
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
||||
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||
|
||||
int* B_ptr = B
|
||||
+ ((int)threadIdx.y) * (OC / 8) * 4
|
||||
+ (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
|
||||
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
|
||||
+ (((int)threadIdx.x) % (64 / 8)) * 1;
|
||||
// Why * 1 in the above line?
|
||||
|
||||
half* A_shared_ptr = A_shared
|
||||
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
||||
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
||||
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||
|
||||
half* B_shared_ptr = B_shared
|
||||
+ ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
|
||||
+ (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
|
||||
+ (((int)threadIdx.x) % (64 / 8)) * 8;
|
||||
|
||||
int* zeros_ptr = zeros
|
||||
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
|
||||
+ ((int)threadIdx.x) % (64 / 8);
|
||||
|
||||
half* scaling_factors_ptr = scaling_factors
|
||||
+ (((int)blockIdx_y) % j_factors1) * (64)
|
||||
+ (((int)threadIdx.x) % (64 / 8)) * 8;
|
||||
|
||||
half* C_ptr = C
|
||||
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
|
||||
+ (((int)blockIdx_y) % j_factors1) * 64
|
||||
+ ((int)threadIdx.y) * 32
|
||||
+ (((int)threadIdx.x) % 4) * 2;
|
||||
|
||||
// preload s.f. and zeros
|
||||
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
|
||||
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
|
||||
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||
__syncthreads();
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
if (ld_A_flag)
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||
}
|
||||
else
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||
/*
|
||||
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
||||
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||
}
|
||||
*/
|
||||
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||
|
||||
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
|
||||
|
||||
// B: 32 x 136 (128+8) float16
|
||||
// each warp: 32 x 4
|
||||
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
||||
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
||||
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
|
||||
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
// - zero and * scale
|
||||
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
/*
|
||||
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
||||
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||
}
|
||||
*/
|
||||
|
||||
// write back
|
||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
|
||||
{
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
|
||||
{
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
|
||||
{
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Shang: Hoist loop invariance.
|
||||
for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
|
||||
for (int local_id = 0; local_id < 8; ++local_id) {
|
||||
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||
if (row_offset < M)
|
||||
{
|
||||
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace awq
|
||||
} // namespace vllm
|
||||
|
||||
// in_feats: M, IC [float16]
|
||||
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
|
||||
// scaling_factors: IC // G, OC [float16]
|
||||
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
|
||||
// assume that batch_size < 16 for now
|
||||
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters)
|
||||
{
|
||||
int num_in_feats = _in_feats.size(0);
|
||||
int num_in_channels = _in_feats.size(1);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
|
||||
|
||||
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
|
||||
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
|
||||
int num_out_feats = _out_feats.size(-2);
|
||||
int num_out_channels = _out_feats.size(-1);
|
||||
|
||||
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
|
||||
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
|
||||
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
|
||||
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
|
||||
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
|
||||
int group_size = num_in_channels / _scaling_factors.size(0);
|
||||
|
||||
if (num_out_channels % 64 != 0)
|
||||
throw std::invalid_argument("OC is not multiple of cta_N = 64");
|
||||
if (num_out_channels % 8 != 0)
|
||||
throw std::invalid_argument("OC is not multiple of pack_num = 8");
|
||||
if (group_size % 32 != 0)
|
||||
throw std::invalid_argument("Group size should be a multiple of 32");
|
||||
if (num_out_channels % group_size != 0)
|
||||
throw std::invalid_argument("OC is not multiple of Group size");
|
||||
|
||||
if (num_out_channels % 128 == 0)
|
||||
{
|
||||
int j_factors1 = num_out_channels / 128 / 1;
|
||||
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||
// threadIdx.x: 32
|
||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||
dim3 threads_per_block(32, 2);
|
||||
vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||
}
|
||||
else if (num_out_channels % 64 == 0)
|
||||
{
|
||||
int j_factors1 = num_out_channels / 64 / 1;
|
||||
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||
|
||||
// threadIdx.x: 32
|
||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||
dim3 threads_per_block(32, 2);
|
||||
vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||
}
|
||||
return _out_feats.sum(0);
|
||||
}
|
@ -3,31 +3,15 @@
|
||||
Installation
|
||||
============
|
||||
|
||||
vLLM is a Python library that also contains some C++ and CUDA code.
|
||||
This additional code requires compilation on the user's machine.
|
||||
vLLM is a Python library that also contains pre-compiled C++ and CUDA (11.8) binaries.
|
||||
|
||||
Requirements
|
||||
------------
|
||||
|
||||
* OS: Linux
|
||||
* Python: 3.8 or higher
|
||||
* CUDA: 11.0 -- 11.8
|
||||
* Python: 3.8 -- 3.11
|
||||
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
|
||||
|
||||
.. note::
|
||||
As of now, vLLM does not support CUDA 12.
|
||||
If you are using Hopper or Lovelace GPUs, please use CUDA 11.8 instead of CUDA 12.
|
||||
|
||||
.. tip::
|
||||
If you have trouble installing vLLM, we recommend using the NVIDIA PyTorch Docker image.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ # Pull the Docker image with CUDA 11.8.
|
||||
$ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
|
||||
|
||||
Inside the Docker container, please execute :code:`pip uninstall torch` before installing vLLM.
|
||||
|
||||
Install with pip
|
||||
----------------
|
||||
|
||||
@ -40,7 +24,7 @@ You can install vLLM using pip:
|
||||
$ conda activate myenv
|
||||
|
||||
$ # Install vLLM.
|
||||
$ pip install vllm # This may take 5-10 minutes.
|
||||
$ pip install vllm
|
||||
|
||||
|
||||
.. _build_from_source:
|
||||
@ -55,3 +39,12 @@ You can also build and install vLLM from source:
|
||||
$ git clone https://github.com/vllm-project/vllm.git
|
||||
$ cd vllm
|
||||
$ pip install -e . # This may take 5-10 minutes.
|
||||
|
||||
.. tip::
|
||||
If you have trouble building vLLM, we recommend using the NVIDIA PyTorch Docker image.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ # Pull the Docker image with CUDA 11.8.
|
||||
$ # Use `--ipc=host` to make sure the shared memory is large enough.
|
||||
$ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:22.12-py3
|
||||
|
@ -128,4 +128,4 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep
|
||||
prompt="San Francisco is a")
|
||||
print("Completion result:", completion)
|
||||
|
||||
For a more detailed client example, refer to `examples/openai_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_client.py>`_.
|
||||
For a more detailed client example, refer to `examples/openai_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`_.
|
||||
|
@ -43,6 +43,7 @@ vLLM is flexible and easy to use with:
|
||||
For more information, check out the following:
|
||||
|
||||
* `vLLM announcing blog post <https://vllm.ai>`_ (intro to PagedAttention)
|
||||
* `vLLM paper <https://arxiv.org/abs/2309.06180>`_ (SOSP 2023)
|
||||
* `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency <https://www.anyscale.com/blog/continuous-batching-llm-inference>`_ by Cade Daniel et al.
|
||||
|
||||
|
||||
@ -63,6 +64,7 @@ Documentation
|
||||
|
||||
serving/distributed_serving
|
||||
serving/run_on_sky
|
||||
serving/deploying_with_triton
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
@ -44,6 +44,9 @@ Alongside each architecture, we include some popular models that use it.
|
||||
* - :code:`LlamaForCausalLM`
|
||||
- LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
|
||||
- :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, etc.
|
||||
* - :code:`MistralForCausalLM`
|
||||
- Mistral, Mistral-Instruct
|
||||
- :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
|
||||
* - :code:`MPTForCausalLM`
|
||||
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
|
||||
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
|
||||
|
6
docs/source/serving/deploying_with_triton.rst
Normal file
6
docs/source/serving/deploying_with_triton.rst
Normal file
@ -0,0 +1,6 @@
|
||||
.. _deploying_with_triton:
|
||||
|
||||
Deploying with NVIDIA Triton
|
||||
============================
|
||||
|
||||
The `Triton Inference Server <https://github.com/triton-inference-server>`_ hosts a tutorial demonstrating how to quickly deploy a simple `facebook/opt-125m <https://huggingface.co/facebook/opt-125m>`_ model using vLLM. Please see `Deploying a vLLM model in Triton <https://github.com/triton-inference-server/tutorials/blob/main/Quick_Deploy/vLLM/README.md#deploying-a-vllm-model-in-triton>`_ for more details.
|
@ -11,3 +11,4 @@ types-setuptools
|
||||
# testing
|
||||
pytest
|
||||
pytest-forked
|
||||
pytest-asyncio
|
||||
|
@ -1,11 +1,13 @@
|
||||
ninja # For faster builds.
|
||||
psutil
|
||||
ray >= 2.5.1
|
||||
pandas # Required for Ray data.
|
||||
pyarrow # Required for Ray data.
|
||||
sentencepiece # Required for LLaMA tokenizer.
|
||||
numpy
|
||||
torch >= 2.0.0
|
||||
transformers >= 4.33.1 # Required for Code Llama.
|
||||
xformers >= 0.0.21
|
||||
xformers >= 0.0.22
|
||||
fastapi
|
||||
uvicorn
|
||||
uvicorn[standard]
|
||||
pydantic < 2 # Required for OpenAI server.
|
||||
|
171
setup.py
171
setup.py
@ -3,6 +3,7 @@ import os
|
||||
import re
|
||||
import subprocess
|
||||
from typing import List, Set
|
||||
import warnings
|
||||
|
||||
from packaging.version import parse, Version
|
||||
import setuptools
|
||||
@ -11,6 +12,9 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
|
||||
|
||||
ROOT_DIR = os.path.dirname(__file__)
|
||||
|
||||
# Supported NVIDIA GPU architectures.
|
||||
SUPPORTED_ARCHS = ["7.0", "7.5", "8.0", "8.6", "8.9", "9.0"]
|
||||
|
||||
# Compiler flags.
|
||||
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
|
||||
# TODO(woosuk): Should we use -O3?
|
||||
@ -22,7 +26,7 @@ NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
||||
|
||||
if CUDA_HOME is None:
|
||||
raise RuntimeError(
|
||||
f"Cannot find CUDA_HOME. CUDA must be available to build the package.")
|
||||
"Cannot find CUDA_HOME. CUDA must be available to build the package.")
|
||||
|
||||
|
||||
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
||||
@ -38,47 +42,82 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
||||
return nvcc_cuda_version
|
||||
|
||||
|
||||
# Collect the compute capabilities of all available GPUs.
|
||||
device_count = torch.cuda.device_count()
|
||||
compute_capabilities: Set[int] = set()
|
||||
for i in range(device_count):
|
||||
major, minor = torch.cuda.get_device_capability(i)
|
||||
if major < 7:
|
||||
raise RuntimeError(
|
||||
"GPUs with compute capability less than 7.0 are not supported.")
|
||||
compute_capabilities.add(major * 10 + minor)
|
||||
def get_torch_arch_list() -> Set[str]:
|
||||
# TORCH_CUDA_ARCH_LIST can have one or more architectures,
|
||||
# e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
|
||||
# compiler to additionally include PTX code that can be runtime-compiled
|
||||
# and executed on the 8.6 or newer architectures. While the PTX code will
|
||||
# not give the best performance on the newer architectures, it provides
|
||||
# forward compatibility.
|
||||
valid_arch_strs = SUPPORTED_ARCHS + [s + "+PTX" for s in SUPPORTED_ARCHS]
|
||||
arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
|
||||
if arch_list is None:
|
||||
return set()
|
||||
|
||||
# List are separated by ; or space.
|
||||
arch_list = arch_list.replace(" ", ";").split(";")
|
||||
for arch in arch_list:
|
||||
if arch not in valid_arch_strs:
|
||||
raise ValueError(
|
||||
f"Unsupported CUDA arch ({arch}). "
|
||||
f"Valid CUDA arch strings are: {valid_arch_strs}.")
|
||||
return set(arch_list)
|
||||
|
||||
|
||||
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
|
||||
compute_capabilities = get_torch_arch_list()
|
||||
if not compute_capabilities:
|
||||
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
|
||||
# GPUs on the current machine.
|
||||
device_count = torch.cuda.device_count()
|
||||
for i in range(device_count):
|
||||
major, minor = torch.cuda.get_device_capability(i)
|
||||
if major < 7:
|
||||
raise RuntimeError(
|
||||
"GPUs with compute capability below 7.0 are not supported.")
|
||||
compute_capabilities.add(f"{major}.{minor}")
|
||||
|
||||
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
||||
if not compute_capabilities:
|
||||
# If no GPU is specified nor available, add all supported architectures
|
||||
# based on the NVCC CUDA version.
|
||||
compute_capabilities = set(SUPPORTED_ARCHS)
|
||||
if nvcc_cuda_version < Version("11.1"):
|
||||
compute_capabilities.remove("8.6")
|
||||
if nvcc_cuda_version < Version("11.8"):
|
||||
compute_capabilities.remove("8.9")
|
||||
compute_capabilities.remove("9.0")
|
||||
|
||||
# Validate the NVCC CUDA version.
|
||||
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
||||
if nvcc_cuda_version < Version("11.0"):
|
||||
raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
|
||||
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
|
||||
if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
||||
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
|
||||
# However, GPUs with compute capability 8.9 can also run the code generated by
|
||||
# the previous versions of CUDA 11 and targeting compute capability 8.0.
|
||||
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
|
||||
# instead of 8.9.
|
||||
compute_capabilities.remove(89)
|
||||
compute_capabilities.add(80)
|
||||
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
|
||||
|
||||
# If no GPU is available, add all supported compute capabilities.
|
||||
if not compute_capabilities:
|
||||
compute_capabilities = {70, 75, 80}
|
||||
if nvcc_cuda_version >= Version("11.1"):
|
||||
compute_capabilities.add(86)
|
||||
if nvcc_cuda_version >= Version("11.8"):
|
||||
compute_capabilities.add(89)
|
||||
compute_capabilities.add(90)
|
||||
if nvcc_cuda_version < Version("11.1"):
|
||||
if any(cc.startswith("8.6") for cc in compute_capabilities):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.1 or higher is required for compute capability 8.6.")
|
||||
if nvcc_cuda_version < Version("11.8"):
|
||||
if any(cc.startswith("8.9") for cc in compute_capabilities):
|
||||
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
|
||||
# However, GPUs with compute capability 8.9 can also run the code generated by
|
||||
# the previous versions of CUDA 11 and targeting compute capability 8.0.
|
||||
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
|
||||
# instead of 8.9.
|
||||
warnings.warn(
|
||||
"CUDA 11.8 or higher is required for compute capability 8.9. "
|
||||
"Targeting compute capability 8.0 instead.")
|
||||
compute_capabilities = set(cc for cc in compute_capabilities
|
||||
if not cc.startswith("8.9"))
|
||||
compute_capabilities.add("8.0+PTX")
|
||||
if any(cc.startswith("9.0") for cc in compute_capabilities):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.8 or higher is required for compute capability 9.0.")
|
||||
|
||||
# Add target compute capabilities to NVCC flags.
|
||||
for capability in compute_capabilities:
|
||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
|
||||
num = capability[0] + capability[2]
|
||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
|
||||
if capability.endswith("+PTX"):
|
||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
|
||||
|
||||
# Use NVCC threads to parallelize the build.
|
||||
if nvcc_cuda_version >= Version("11.2"):
|
||||
@ -91,7 +130,10 @@ ext_modules = []
|
||||
cache_extension = CUDAExtension(
|
||||
name="vllm.cache_ops",
|
||||
sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
|
||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(cache_extension)
|
||||
|
||||
@ -99,7 +141,10 @@ ext_modules.append(cache_extension)
|
||||
attention_extension = CUDAExtension(
|
||||
name="vllm.attention_ops",
|
||||
sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
|
||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(attention_extension)
|
||||
|
||||
@ -107,7 +152,10 @@ ext_modules.append(attention_extension)
|
||||
positional_encoding_extension = CUDAExtension(
|
||||
name="vllm.pos_encoding_ops",
|
||||
sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
|
||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(positional_encoding_extension)
|
||||
|
||||
@ -115,7 +163,10 @@ ext_modules.append(positional_encoding_extension)
|
||||
layernorm_extension = CUDAExtension(
|
||||
name="vllm.layernorm_ops",
|
||||
sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
|
||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(layernorm_extension)
|
||||
|
||||
@ -123,10 +174,38 @@ ext_modules.append(layernorm_extension)
|
||||
activation_extension = CUDAExtension(
|
||||
name="vllm.activation_ops",
|
||||
sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
|
||||
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(activation_extension)
|
||||
|
||||
# Quantization kernels.
|
||||
quantization_extension = CUDAExtension(
|
||||
name="vllm.quantization_ops",
|
||||
sources=[
|
||||
"csrc/quantization.cpp",
|
||||
"csrc/quantization/awq/gemm_kernels.cu",
|
||||
],
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(quantization_extension)
|
||||
|
||||
# Misc. CUDA utils.
|
||||
cuda_utils_extension = CUDAExtension(
|
||||
name="vllm.cuda_utils",
|
||||
sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"],
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(cuda_utils_extension)
|
||||
|
||||
|
||||
def get_path(*filepath) -> str:
|
||||
return os.path.join(ROOT_DIR, *filepath)
|
||||
@ -138,8 +217,8 @@ def find_version(filepath: str):
|
||||
Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py
|
||||
"""
|
||||
with open(filepath) as fp:
|
||||
version_match = re.search(
|
||||
r"^__version__ = ['\"]([^'\"]*)['\"]", fp.read(), re.M)
|
||||
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
|
||||
fp.read(), re.M)
|
||||
if version_match:
|
||||
return version_match.group(1)
|
||||
raise RuntimeError("Unable to find version string.")
|
||||
@ -162,7 +241,8 @@ setuptools.setup(
|
||||
version=find_version(get_path("vllm", "__init__.py")),
|
||||
author="vLLM Team",
|
||||
license="Apache 2.0",
|
||||
description="A high-throughput and memory-efficient inference and serving engine for LLMs",
|
||||
description=("A high-throughput and memory-efficient inference and "
|
||||
"serving engine for LLMs"),
|
||||
long_description=read_readme(),
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/vllm-project/vllm",
|
||||
@ -174,11 +254,12 @@ setuptools.setup(
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
],
|
||||
packages=setuptools.find_packages(
|
||||
exclude=("assets", "benchmarks", "csrc", "docs", "examples", "tests")),
|
||||
packages=setuptools.find_packages(exclude=("benchmarks", "csrc", "docs",
|
||||
"examples", "tests")),
|
||||
python_requires=">=3.8",
|
||||
install_requires=get_requirements(),
|
||||
ext_modules=ext_modules,
|
||||
|
@ -40,8 +40,7 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngineWithStats.from_engine_args(engine_args,
|
||||
start_engine_loop=False)
|
||||
engine = AsyncLLMEngineWithStats.from_engine_args(engine_args)
|
||||
vllm.entrypoints.api_server.engine = engine
|
||||
uvicorn.run(
|
||||
app,
|
||||
|
80
tests/async_engine/test_async_llm_engine.py
Normal file
80
tests/async_engine/test_async_llm_engine.py
Normal file
@ -0,0 +1,80 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestOutput:
|
||||
request_id: int
|
||||
finished: bool = False
|
||||
|
||||
|
||||
class MockEngine:
|
||||
|
||||
def __init__(self):
|
||||
self.step_calls = 0
|
||||
self.add_request_calls = 0
|
||||
self.abort_request_calls = 0
|
||||
self.request_id = None
|
||||
|
||||
async def step_async(self):
|
||||
self.step_calls += 1
|
||||
return [RequestOutput(
|
||||
request_id=self.request_id)] if self.request_id else []
|
||||
|
||||
def generate(self, request_id):
|
||||
self.request_id = request_id
|
||||
|
||||
def stop_generating(self):
|
||||
self.request_id = None
|
||||
|
||||
def add_request(self, **kwargs):
|
||||
self.add_request_calls += 1
|
||||
return
|
||||
|
||||
def abort_request(self, request_id):
|
||||
self.abort_request_calls += 1
|
||||
return
|
||||
|
||||
|
||||
class MockAsyncLLMEngine(AsyncLLMEngine):
|
||||
|
||||
def _init_engine(self, *args, **kwargs):
|
||||
return MockEngine()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_requests_event():
|
||||
engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
|
||||
engine.start_background_loop()
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.step_calls == 0
|
||||
|
||||
await engine.add_request("1", "", None)
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.add_request_calls == 1
|
||||
assert engine.engine.step_calls == 1
|
||||
|
||||
await engine.add_request("2", "", None)
|
||||
engine.engine.generate("2")
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.add_request_calls == 2
|
||||
assert engine.engine.step_calls == 2
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.step_calls == 3
|
||||
engine.engine.stop_generating()
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.step_calls == 4
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.step_calls == 4
|
||||
|
||||
await engine.add_request("3", "", None)
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.add_request_calls == 3
|
||||
assert engine.engine.step_calls == 5
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.add_request_calls == 3
|
||||
assert engine.engine.step_calls == 5
|
@ -4,10 +4,25 @@ from vllm.engine.async_llm_engine import RequestTracker
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
|
||||
class DummyEvent:
|
||||
|
||||
def __init__(self):
|
||||
self._flag = False
|
||||
|
||||
def set(self):
|
||||
self._flag = True
|
||||
|
||||
def clear(self):
|
||||
self._flag = False
|
||||
|
||||
|
||||
def test_request_tracker():
|
||||
tracker = RequestTracker()
|
||||
tracker.new_requests_event = DummyEvent()
|
||||
stream_1 = tracker.add_request("1")
|
||||
assert tracker.new_requests_event._flag
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert not tracker.new_requests_event._flag
|
||||
assert len(new) == 1
|
||||
assert new[0]["request_id"] == "1"
|
||||
assert not finished
|
||||
@ -15,7 +30,9 @@ def test_request_tracker():
|
||||
|
||||
stream_2 = tracker.add_request("2")
|
||||
stream_3 = tracker.add_request("3")
|
||||
assert tracker.new_requests_event._flag
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert not tracker.new_requests_event._flag
|
||||
assert len(new) == 2
|
||||
assert new[0]["request_id"] == "2"
|
||||
assert new[1]["request_id"] == "3"
|
||||
@ -26,6 +43,7 @@ def test_request_tracker():
|
||||
# request_ids must be unique
|
||||
with pytest.raises(KeyError):
|
||||
tracker.add_request("1")
|
||||
assert not tracker.new_requests_event._flag
|
||||
|
||||
tracker.abort_request("1")
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
@ -36,6 +54,7 @@ def test_request_tracker():
|
||||
|
||||
stream_4 = tracker.add_request("4")
|
||||
tracker.abort_request("4")
|
||||
assert tracker.new_requests_event._flag
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert len(finished) == 1
|
||||
assert "4" in finished
|
||||
@ -43,9 +62,11 @@ def test_request_tracker():
|
||||
assert stream_4.finished
|
||||
|
||||
stream_5 = tracker.add_request("5")
|
||||
assert tracker.new_requests_event._flag
|
||||
tracker.process_request_output(
|
||||
RequestOutput("2", "output", [], [], finished=True))
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert not tracker.new_requests_event._flag
|
||||
assert len(finished) == 1
|
||||
assert "2" in finished
|
||||
assert len(new) == 1
|
||||
|
62
tests/engine/test_detokenize.py
Normal file
62
tests/engine/test_detokenize.py
Normal file
@ -0,0 +1,62 @@
|
||||
import pytest
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.transformers_utils.tokenizer import detokenize_incrementally
|
||||
|
||||
TRUTH = [
|
||||
"Hello here, this is a simple test",
|
||||
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",
|
||||
"我很感谢你的热情"
|
||||
]
|
||||
TOKENIZERS = [
|
||||
"facebook/opt-125m",
|
||||
"gpt2",
|
||||
"bigcode/tiny_starcoder_py",
|
||||
"EleutherAI/gpt-j-6b",
|
||||
"EleutherAI/pythia-70m",
|
||||
"bigscience/bloom-560m",
|
||||
"mosaicml/mpt-7b",
|
||||
"tiiuae/falcon-7b",
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
"codellama/CodeLlama-7b-hf",
|
||||
]
|
||||
|
||||
|
||||
def _run_incremental_decode(tokenizer, all_input_ids,
|
||||
skip_special_tokens: bool):
|
||||
decoded_text = ""
|
||||
offset = 0
|
||||
token_offset = 0
|
||||
prev_tokens = None
|
||||
for i in range(len(all_input_ids)):
|
||||
new_tokens, text, offset, token_offset = detokenize_incrementally(
|
||||
tokenizer,
|
||||
all_input_ids[:i + 1],
|
||||
prev_tokens,
|
||||
offset,
|
||||
token_offset,
|
||||
skip_special_tokens=skip_special_tokens)
|
||||
decoded_text += text
|
||||
if prev_tokens is None:
|
||||
prev_tokens = new_tokens
|
||||
else:
|
||||
prev_tokens += new_tokens
|
||||
return decoded_text
|
||||
|
||||
|
||||
@pytest.mark.parametrize("truth", TRUTH)
|
||||
@pytest.mark.parametrize("tokenizer_id", TOKENIZERS)
|
||||
@pytest.mark.parametrize("skip_special_tokens", (True, False))
|
||||
def test_decode_streaming(tokenizer_id, truth, skip_special_tokens):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
|
||||
all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"]
|
||||
if skip_special_tokens:
|
||||
all_input_ids = ([tokenizer.bos_token_id]
|
||||
if tokenizer.bos_token_id is not None else
|
||||
[]) + all_input_ids + [tokenizer.eos_token_id]
|
||||
|
||||
decoded_text = _run_incremental_decode(
|
||||
tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
assert decoded_text == truth
|
@ -7,8 +7,12 @@ from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||
|
||||
from vllm import attention_ops
|
||||
from vllm.utils import get_max_shared_memory_bytes
|
||||
|
||||
MAX_SEQ_LEN = 8192
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
# This will change depending on the compute capability.
|
||||
# - 512 as a buffer
|
||||
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
|
||||
NUM_BLOCKS = 128 # Arbitrary values for testing
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
@ -135,6 +139,7 @@ def test_single_query_cached_kv_attention(
|
||||
device="cuda")
|
||||
|
||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
context_lens[-1] = MAX_SEQ_LEN
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
|
||||
|
||||
@ -242,7 +247,11 @@ def test_multi_query_kv_attention(
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
|
||||
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
|
||||
# As the xformers library is already tested with its own tests, we can use
|
||||
# a smaller MAX_SEQ_LEN here.
|
||||
max_len = min(MAX_SEQ_LEN, 4096)
|
||||
seq_lens = random.sample(range(1, max_len), num_seqs)
|
||||
num_tokens = sum(seq_lens)
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
|
@ -133,9 +133,10 @@ def test_rotary_embedding(
|
||||
device="cuda")
|
||||
|
||||
# Create the rotary embedding.
|
||||
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
|
||||
inv_freq = 1.0 / (base**(
|
||||
torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
|
||||
t = torch.arange(max_position).float()
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
||||
|
184
tests/samplers/test_sampler.py
Normal file
184
tests/samplers/test_sampler.py
Normal file
@ -0,0 +1,184 @@
|
||||
import pytest
|
||||
import random
|
||||
from typing import Tuple
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
|
||||
class MockLogitsSampler(Sampler):
|
||||
|
||||
def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
|
||||
super().__init__(vocab_size=vocab_size)
|
||||
self.fake_logits = fake_logits
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
|
||||
lambda x, y: x):
|
||||
with patch("vllm.model_executor.layers.sampler._get_logits",
|
||||
lambda *args, **kwargs: self.fake_logits):
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
|
||||
def _prepare_test(
|
||||
batch_size: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
|
||||
vocab_size = 32000
|
||||
input_tensor = torch.rand((batch_size, 1024),
|
||||
device="cuda",
|
||||
dtype=torch.float16)
|
||||
fake_logits = torch.full((batch_size, vocab_size),
|
||||
1e-2,
|
||||
device=input_tensor.device,
|
||||
dtype=input_tensor.dtype)
|
||||
sampler = MockLogitsSampler(32000, fake_logits)
|
||||
worker = Worker(None, None, None)
|
||||
worker.block_size = 16
|
||||
return input_tensor, fake_logits, sampler, worker
|
||||
|
||||
|
||||
RANDOM_SEEDS = list(range(128))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
def test_sampler_all_greedy(seed: int):
|
||||
set_random_seed(seed)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
|
||||
|
||||
seq_group_metadata_list = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
sampling_params=SamplingParams(temperature=0, ),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
|
||||
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
|
||||
sampler_output = sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
input_metadata=input_metadata)
|
||||
expected = torch.argmax(fake_logits, dim=-1)
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
for nth_output in sequence_output:
|
||||
assert nth_output.output_token == expected[i].item()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
def test_sampler_all_random(seed: int):
|
||||
set_random_seed(seed)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
|
||||
|
||||
for i in range(batch_size):
|
||||
fake_logits[i, i] = 1e2
|
||||
|
||||
seq_group_metadata_list = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
sampling_params=SamplingParams(
|
||||
temperature=1.0,
|
||||
n=random.randint(1, 10),
|
||||
),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
|
||||
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
|
||||
sampler_output = sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
input_metadata=input_metadata)
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
for nth_output in sequence_output:
|
||||
assert nth_output.output_token == i
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
def test_sampler_all_beam(seed: int):
|
||||
set_random_seed(seed)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
|
||||
|
||||
seq_group_metadata_list = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
sampling_params=SamplingParams(
|
||||
temperature=0,
|
||||
best_of=2,
|
||||
use_beam_search=True,
|
||||
),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
|
||||
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
|
||||
sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
input_metadata=input_metadata)
|
||||
# no assertion here as I am not sure how to determine whether
|
||||
# the outputs are expected - in other words, this just tests
|
||||
# whether there are no exceptions in the sampler
|
||||
# when handling an all-beam search case.
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
def test_sampler_mixed(seed: int):
|
||||
set_random_seed(seed)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
|
||||
|
||||
seq_group_metadata_list = []
|
||||
expected_tokens = []
|
||||
for i in range(batch_size):
|
||||
n = 1
|
||||
sampling_type = random.randint(0, 2)
|
||||
if sampling_type == 0:
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
elif sampling_type == 1:
|
||||
n = random.randint(1, 10)
|
||||
sampling_params = SamplingParams(
|
||||
temperature=random.random() + 0.1,
|
||||
top_p=min(random.random() + 0.1, 1),
|
||||
top_k=random.randint(0, 10) or -1,
|
||||
n=n,
|
||||
presence_penalty=random.randint(0, 1),
|
||||
)
|
||||
else:
|
||||
sampling_params = SamplingParams(temperature=0,
|
||||
use_beam_search=True,
|
||||
best_of=2)
|
||||
for idx in range(n):
|
||||
fake_logits[i, i + idx] = 1e2
|
||||
expected_tokens.append(i + idx)
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
sampling_params=sampling_params,
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
|
||||
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
|
||||
sampler_output = sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
input_metadata=input_metadata)
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
if seq_group_metadata_list[i].sampling_params.use_beam_search:
|
||||
continue
|
||||
for nth_output in sequence_output:
|
||||
assert nth_output.output_token in expected_tokens
|
@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
__version__ = "0.1.5"
|
||||
__version__ = "0.2.0"
|
||||
|
||||
__all__ = [
|
||||
"LLM",
|
||||
|
144
vllm/config.py
144
vllm/config.py
@ -38,6 +38,13 @@ class ModelConfig:
|
||||
will use FP16 precision for FP32 and FP16 models, and BF16 precision
|
||||
for BF16 models.
|
||||
seed: Random seed for reproducibility.
|
||||
revision: The specific model version to use. It can be a branch name,
|
||||
a tag name, or a commit id. If unspecified, will use the default
|
||||
version.
|
||||
max_model_len: Maximum length of a sequence (including prompt and
|
||||
output). If None, will be derived from the model.
|
||||
quantization: Quantization method that was used to quantize the model
|
||||
weights. If None, we assume the model weights are not quantized.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -50,6 +57,9 @@ class ModelConfig:
|
||||
load_format: str,
|
||||
dtype: str,
|
||||
seed: int,
|
||||
revision: Optional[str] = None,
|
||||
max_model_len: Optional[int] = None,
|
||||
quantization: Optional[str] = None,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
@ -58,11 +68,16 @@ class ModelConfig:
|
||||
self.download_dir = download_dir
|
||||
self.load_format = load_format
|
||||
self.seed = seed
|
||||
self.revision = revision
|
||||
self.quantization = quantization
|
||||
|
||||
self.hf_config = get_config(model, trust_remote_code)
|
||||
self.hf_config = get_config(model, trust_remote_code, revision)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
||||
self.max_model_len = _get_and_verify_max_len(self.hf_config,
|
||||
max_model_len)
|
||||
self._verify_load_format()
|
||||
self._verify_tokenizer_mode()
|
||||
self._verify_quantization()
|
||||
|
||||
def _verify_load_format(self) -> None:
|
||||
load_format = self.load_format.lower()
|
||||
@ -82,6 +97,17 @@ class ModelConfig:
|
||||
"either 'auto' or 'slow'.")
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
|
||||
def _verify_quantization(self) -> None:
|
||||
supported_quantization = ["awq"]
|
||||
if self.quantization is None:
|
||||
return
|
||||
quantization = self.quantization.lower()
|
||||
if quantization not in supported_quantization:
|
||||
raise ValueError(
|
||||
f"Unknown quantization: {self.quantization}. Must be one of "
|
||||
f"{supported_quantization}.")
|
||||
self.quantization = quantization
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
parallel_config: "ParallelConfig",
|
||||
@ -109,22 +135,28 @@ class ModelConfig:
|
||||
# FIXME(woosuk): This may not be true for all models.
|
||||
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
|
||||
|
||||
def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
|
||||
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
|
||||
"""Returns the number of KV heads per GPU worker."""
|
||||
# For GPTBigCode & Falcon:
|
||||
# Note: for falcon, when new_decoder_architecture is True, the
|
||||
# multi_query flag is ignored and we use n_head_kv for the number of
|
||||
# KV heads.
|
||||
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
|
||||
new_decoder_arch_falcon = (
|
||||
self.hf_config.model_type == "falcon"
|
||||
self.hf_config.model_type in falcon_model_types
|
||||
and getattr(self.hf_config, "new_decoder_architecture", False))
|
||||
if not new_decoder_arch_falcon and getattr(self.hf_config,
|
||||
"multi_query", False):
|
||||
# Multi-query attention, only one KV head.
|
||||
# Currently, tensor parallelism is not supported in this case.
|
||||
return 1
|
||||
# For Falcon:
|
||||
if getattr(self.hf_config, "n_head_kv", None) is not None:
|
||||
return (self.hf_config.n_head_kv //
|
||||
parallel_config.tensor_parallel_size)
|
||||
if getattr(self.hf_config, "num_kv_heads", None) is not None:
|
||||
return (self.hf_config.num_kv_heads //
|
||||
parallel_config.tensor_parallel_size)
|
||||
# For LLaMA-2:
|
||||
if getattr(self.hf_config, "num_key_value_heads", None) is not None:
|
||||
return (self.hf_config.num_key_value_heads //
|
||||
@ -132,26 +164,6 @@ class ModelConfig:
|
||||
total_num_attention_heads = self.hf_config.num_attention_heads
|
||||
return total_num_attention_heads // parallel_config.tensor_parallel_size
|
||||
|
||||
def get_max_model_len(self) -> int:
|
||||
max_model_len = float("inf")
|
||||
possible_keys = [
|
||||
# OPT
|
||||
"max_position_embeddings",
|
||||
# GPT-2
|
||||
"n_positions",
|
||||
# MPT
|
||||
"max_seq_len",
|
||||
# Others
|
||||
"max_sequence_length",
|
||||
"max_seq_length",
|
||||
"seq_len",
|
||||
]
|
||||
for key in possible_keys:
|
||||
max_len_key = getattr(self.hf_config, key, None)
|
||||
if max_len_key is not None:
|
||||
max_model_len = min(max_model_len, max_len_key)
|
||||
return max_model_len
|
||||
|
||||
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
|
||||
total_num_hidden_layers = self.hf_config.num_hidden_layers
|
||||
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
|
||||
@ -172,10 +184,12 @@ class CacheConfig:
|
||||
block_size: int,
|
||||
gpu_memory_utilization: float,
|
||||
swap_space: int,
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> None:
|
||||
self.block_size = block_size
|
||||
self.gpu_memory_utilization = gpu_memory_utilization
|
||||
self.swap_space_bytes = swap_space * _GB
|
||||
self.sliding_window = sliding_window
|
||||
self._verify_args()
|
||||
|
||||
# Will be set after profiling.
|
||||
@ -251,11 +265,36 @@ class SchedulerConfig:
|
||||
and generated text).
|
||||
"""
|
||||
|
||||
def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
|
||||
max_model_len: int) -> None:
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
def __init__(
|
||||
self,
|
||||
max_num_batched_tokens: Optional[int],
|
||||
max_num_seqs: int,
|
||||
max_model_len: int,
|
||||
) -> None:
|
||||
if max_num_batched_tokens is not None:
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
else:
|
||||
# If max_model_len is too short, use 2048 as the default value for
|
||||
# higher throughput.
|
||||
self.max_num_batched_tokens = max(max_model_len, 2048)
|
||||
self.max_num_seqs = max_num_seqs
|
||||
self.max_model_len = max_model_len
|
||||
self._verify_args()
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
if self.max_num_batched_tokens < self.max_model_len:
|
||||
raise ValueError(
|
||||
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
|
||||
f"smaller than max_model_len ({self.max_model_len}). "
|
||||
"This effectively limits the maximum sequence length to "
|
||||
"max_num_batched_tokens and makes vLLM reject longer "
|
||||
"sequences. Please increase max_num_batched_tokens or "
|
||||
"decrease max_model_len.")
|
||||
if self.max_num_batched_tokens < self.max_num_seqs:
|
||||
raise ValueError(
|
||||
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
|
||||
"be greater than or equal to max_num_seqs "
|
||||
f"({self.max_num_seqs}).")
|
||||
|
||||
|
||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
@ -311,3 +350,56 @@ def _get_and_verify_dtype(
|
||||
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
|
||||
f"{compute_capability[0]}.{compute_capability[1]}.")
|
||||
return torch_dtype
|
||||
|
||||
|
||||
def _get_and_verify_max_len(
|
||||
hf_config: PretrainedConfig,
|
||||
max_model_len: Optional[int],
|
||||
) -> int:
|
||||
"""Get and verify the model's maximum length."""
|
||||
derived_max_model_len = float("inf")
|
||||
possible_keys = [
|
||||
# OPT
|
||||
"max_position_embeddings",
|
||||
# GPT-2
|
||||
"n_positions",
|
||||
# MPT
|
||||
"max_seq_len",
|
||||
# Others
|
||||
"max_sequence_length",
|
||||
"max_seq_length",
|
||||
"seq_len",
|
||||
]
|
||||
for key in possible_keys:
|
||||
max_len_key = getattr(hf_config, key, None)
|
||||
if max_len_key is not None:
|
||||
derived_max_model_len = min(derived_max_model_len, max_len_key)
|
||||
if derived_max_model_len == float("inf"):
|
||||
if max_model_len is not None:
|
||||
# If max_model_len is specified, we use it.
|
||||
return max_model_len
|
||||
|
||||
default_max_len = 2048
|
||||
logger.warning(
|
||||
"The model's config.json does not contain any of the following "
|
||||
"keys to determine the original maximum length of the model: "
|
||||
f"{possible_keys}. Assuming the model's maximum length is "
|
||||
f"{default_max_len}.")
|
||||
derived_max_model_len = default_max_len
|
||||
|
||||
rope_scaling = getattr(hf_config, "rope_scaling", None)
|
||||
if rope_scaling is not None:
|
||||
assert "factor" in rope_scaling
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
derived_max_model_len *= scaling_factor
|
||||
|
||||
if max_model_len is None:
|
||||
max_model_len = derived_max_model_len
|
||||
elif max_model_len > derived_max_model_len:
|
||||
raise ValueError(
|
||||
f"User-specified max_model_len ({max_model_len}) is greater than "
|
||||
f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
|
||||
" in model's config.json). This may lead to incorrect model "
|
||||
"outputs or CUDA errors. Make sure the value is correct and "
|
||||
"within the model context size.")
|
||||
return int(max_model_len)
|
||||
|
@ -63,10 +63,18 @@ class BlockSpaceManager:
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
watermark: float = 0.01,
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> None:
|
||||
self.block_size = block_size
|
||||
self.num_total_gpu_blocks = num_gpu_blocks
|
||||
self.num_total_cpu_blocks = num_cpu_blocks
|
||||
|
||||
self.block_sliding_window = None
|
||||
if sliding_window is not None:
|
||||
assert sliding_window % block_size == 0, (sliding_window,
|
||||
block_size)
|
||||
self.block_sliding_window = sliding_window // block_size
|
||||
|
||||
self.watermark = watermark
|
||||
assert watermark >= 0.0
|
||||
|
||||
@ -83,6 +91,9 @@ class BlockSpaceManager:
|
||||
# the same prompt. This may not be true for preempted sequences.
|
||||
seq = seq_group.get_seqs()[0]
|
||||
num_required_blocks = len(seq.logical_token_blocks)
|
||||
if self.block_sliding_window is not None:
|
||||
num_required_blocks = min(num_required_blocks,
|
||||
self.block_sliding_window)
|
||||
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
|
||||
# Use watermark to avoid frequent cache eviction.
|
||||
return (num_free_gpu_blocks - num_required_blocks >=
|
||||
@ -95,8 +106,12 @@ class BlockSpaceManager:
|
||||
|
||||
# Allocate new physical token blocks that will store the prompt tokens.
|
||||
block_table: BlockTable = []
|
||||
for _ in range(len(seq.logical_token_blocks)):
|
||||
block = self.gpu_allocator.allocate()
|
||||
for logical_idx in range(len(seq.logical_token_blocks)):
|
||||
if (self.block_sliding_window is not None
|
||||
and logical_idx >= self.block_sliding_window):
|
||||
block = block_table[logical_idx % self.block_sliding_window]
|
||||
else:
|
||||
block = self.gpu_allocator.allocate()
|
||||
# Set the reference counts of the token blocks.
|
||||
block.ref_count = seq_group.num_seqs()
|
||||
block_table.append(block)
|
||||
@ -118,11 +133,17 @@ class BlockSpaceManager:
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
|
||||
if len(block_table) < len(logical_blocks):
|
||||
# The sequence has a new logical block.
|
||||
# Allocate a new physical block.
|
||||
block = self.gpu_allocator.allocate()
|
||||
block_table.append(block)
|
||||
return None
|
||||
if (self.block_sliding_window
|
||||
and len(block_table) >= self.block_sliding_window):
|
||||
# re-use a block
|
||||
block_table.append(block_table[len(block_table) %
|
||||
self.block_sliding_window])
|
||||
else:
|
||||
# The sequence has a new logical block.
|
||||
# Allocate a new physical block.
|
||||
block = self.gpu_allocator.allocate()
|
||||
block_table.append(block)
|
||||
return None
|
||||
|
||||
# We want to append the token to the last physical block.
|
||||
last_block = block_table[-1]
|
||||
@ -154,9 +175,7 @@ class BlockSpaceManager:
|
||||
for seq in seq_group.get_seqs():
|
||||
if seq.is_finished():
|
||||
continue
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
for block in block_table:
|
||||
blocks.add(block)
|
||||
blocks.update(self.block_tables[seq.seq_id])
|
||||
return list(blocks)
|
||||
|
||||
def can_swap_in(self, seq_group: SequenceGroup) -> bool:
|
||||
@ -224,7 +243,7 @@ class BlockSpaceManager:
|
||||
return block_number_mapping
|
||||
|
||||
def _free_block_table(self, block_table: BlockTable) -> None:
|
||||
for block in block_table:
|
||||
for block in set(block_table):
|
||||
if block.device == Device.GPU:
|
||||
self.gpu_allocator.free(block)
|
||||
else:
|
||||
|
@ -73,7 +73,7 @@ class Scheduler:
|
||||
block_size=self.cache_config.block_size,
|
||||
num_gpu_blocks=self.cache_config.num_gpu_blocks,
|
||||
num_cpu_blocks=self.cache_config.num_cpu_blocks,
|
||||
)
|
||||
sliding_window=self.cache_config.sliding_window)
|
||||
|
||||
# TODO(zhuohan): Use deque instead of list for better performance.
|
||||
# Sequence groups in the WAITING state.
|
||||
@ -175,7 +175,7 @@ class Scheduler:
|
||||
num_curr_seqs += num_new_seqs
|
||||
scheduled.append(seq_group)
|
||||
|
||||
if scheduled:
|
||||
if scheduled or ignored_seq_groups:
|
||||
scheduler_outputs = SchedulerOutputs(
|
||||
scheduled_seq_groups=scheduled,
|
||||
prompt_run=True,
|
||||
|
@ -18,20 +18,22 @@ class EngineArgs:
|
||||
load_format: str = 'auto'
|
||||
dtype: str = 'auto'
|
||||
seed: int = 0
|
||||
max_model_len: Optional[int] = None
|
||||
worker_use_ray: bool = False
|
||||
pipeline_parallel_size: int = 1
|
||||
tensor_parallel_size: int = 1
|
||||
block_size: int = 16
|
||||
swap_space: int = 4 # GiB
|
||||
gpu_memory_utilization: float = 0.90
|
||||
max_num_batched_tokens: int = 2560
|
||||
max_num_batched_tokens: Optional[int] = None
|
||||
max_num_seqs: int = 256
|
||||
disable_log_stats: bool = False
|
||||
revision: Optional[str] = None
|
||||
quantization: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer is None:
|
||||
self.tokenizer = self.model
|
||||
self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
@ -48,6 +50,13 @@ class EngineArgs:
|
||||
type=str,
|
||||
default=EngineArgs.tokenizer,
|
||||
help='name or path of the huggingface tokenizer to use')
|
||||
parser.add_argument(
|
||||
'--revision',
|
||||
type=str,
|
||||
default=None,
|
||||
help='the specific model version to use. It can be a branch '
|
||||
'name, a tag name, or a commit id. If unspecified, will use '
|
||||
'the default version.')
|
||||
parser.add_argument('--tokenizer-mode',
|
||||
type=str,
|
||||
default=EngineArgs.tokenizer_mode,
|
||||
@ -79,16 +88,22 @@ class EngineArgs:
|
||||
'a numpy cache to speed up the loading. '
|
||||
'"dummy" will initialize the weights with random values, '
|
||||
'which is mainly for profiling.')
|
||||
# TODO(woosuk): Support FP32.
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
default=EngineArgs.dtype,
|
||||
choices=['auto', 'half', 'bfloat16', 'float'],
|
||||
choices=[
|
||||
'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
|
||||
],
|
||||
help='data type for model weights and activations. '
|
||||
'The "auto" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.')
|
||||
parser.add_argument('--max-model-len',
|
||||
type=int,
|
||||
default=None,
|
||||
help='model context length. If unspecified, '
|
||||
'will be automatically derived from the model.')
|
||||
# Parallel arguments
|
||||
parser.add_argument('--worker-use-ray',
|
||||
action='store_true',
|
||||
@ -136,6 +151,13 @@ class EngineArgs:
|
||||
parser.add_argument('--disable-log-stats',
|
||||
action='store_true',
|
||||
help='disable logging statistics')
|
||||
# Quantization settings.
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
type=str,
|
||||
choices=['awq', None],
|
||||
default=None,
|
||||
help='Method used to quantize the weights')
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
@ -149,20 +171,20 @@ class EngineArgs:
|
||||
def create_engine_configs(
|
||||
self,
|
||||
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
|
||||
# Initialize the configs.
|
||||
model_config = ModelConfig(self.model, self.tokenizer,
|
||||
self.tokenizer_mode, self.trust_remote_code,
|
||||
self.download_dir, self.load_format,
|
||||
self.dtype, self.seed)
|
||||
cache_config = CacheConfig(self.block_size,
|
||||
self.gpu_memory_utilization,
|
||||
self.swap_space)
|
||||
self.dtype, self.seed, self.revision,
|
||||
self.max_model_len, self.quantization)
|
||||
cache_config = CacheConfig(
|
||||
self.block_size, self.gpu_memory_utilization, self.swap_space,
|
||||
getattr(model_config.hf_config, 'sliding_window', None))
|
||||
parallel_config = ParallelConfig(self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size,
|
||||
self.worker_use_ray)
|
||||
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
|
||||
self.max_num_seqs,
|
||||
model_config.get_max_model_len())
|
||||
model_config.max_model_len)
|
||||
return model_config, cache_config, parallel_config, scheduler_config
|
||||
|
||||
|
||||
@ -171,6 +193,7 @@ class AsyncEngineArgs(EngineArgs):
|
||||
"""Arguments for asynchronous vLLM engine."""
|
||||
engine_use_ray: bool = False
|
||||
disable_log_requests: bool = False
|
||||
max_log_len: Optional[int] = None
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(
|
||||
@ -183,4 +206,10 @@ class AsyncEngineArgs(EngineArgs):
|
||||
parser.add_argument('--disable-log-requests',
|
||||
action='store_true',
|
||||
help='disable logging requests')
|
||||
parser.add_argument('--max-log-len',
|
||||
type=int,
|
||||
default=None,
|
||||
help='max number of prompt characters or prompt '
|
||||
'ID numbers being printed in log. '
|
||||
'Default: unlimited.')
|
||||
return parser
|
||||
|
@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
|
||||
Union)
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
@ -78,14 +79,24 @@ class RequestTracker:
|
||||
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
|
||||
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
|
||||
dict]] = asyncio.Queue()
|
||||
self.new_requests_event = None
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self._request_streams
|
||||
|
||||
def propagate_exception(self, exc: Exception) -> None:
|
||||
"""Propagate an exception to all request streams."""
|
||||
for stream in self._request_streams.values():
|
||||
stream.put(exc)
|
||||
def init_event(self):
|
||||
self.new_requests_event = asyncio.Event()
|
||||
|
||||
def propagate_exception(self,
|
||||
exc: Exception,
|
||||
request_id: Optional[str] = None) -> None:
|
||||
"""Propagate an exception to request streams
|
||||
(all if request_id is None)."""
|
||||
if request_id is not None:
|
||||
self._request_streams[request_id].put(exc)
|
||||
else:
|
||||
for stream in self._request_streams.values():
|
||||
stream.put(exc)
|
||||
|
||||
def process_request_output(self,
|
||||
request_output: RequestOutput,
|
||||
@ -112,6 +123,9 @@ class RequestTracker:
|
||||
"request_id": request_id,
|
||||
**engine_add_request_kwargs
|
||||
}))
|
||||
|
||||
self.new_requests_event.set()
|
||||
|
||||
return stream
|
||||
|
||||
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
|
||||
@ -148,8 +162,13 @@ class RequestTracker:
|
||||
self._request_streams[stream.request_id] = stream
|
||||
new_requests.append(new_request)
|
||||
|
||||
self.new_requests_event.clear()
|
||||
|
||||
return new_requests, finished_requests
|
||||
|
||||
async def wait_for_new_requests(self):
|
||||
await self.new_requests_event.wait()
|
||||
|
||||
|
||||
class _AsyncLLMEngine(LLMEngine):
|
||||
"""Extension of LLMEngine to add async methods."""
|
||||
@ -164,10 +183,9 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
and updates the scheduler with the model outputs. Finally, it decodes
|
||||
the sequences and returns the newly generated results.
|
||||
"""
|
||||
(seq_group_metadata_list, scheduler_outputs,
|
||||
early_return) = self._schedule()
|
||||
if early_return is not None:
|
||||
return early_return
|
||||
seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
|
||||
if scheduler_outputs.is_empty():
|
||||
return ignored
|
||||
|
||||
# Execute the model.
|
||||
output = await self._run_workers_async(
|
||||
@ -178,7 +196,7 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||
)
|
||||
|
||||
return self._process_model_outputs(output, scheduler_outputs)
|
||||
return self._process_model_outputs(output, scheduler_outputs) + ignored
|
||||
|
||||
async def _run_workers_async(
|
||||
self,
|
||||
@ -230,6 +248,8 @@ class AsyncLLMEngine:
|
||||
async frontend will be executed in a separate process as the
|
||||
model workers.
|
||||
log_requests: Whether to log the requests.
|
||||
start_engine_loop: If True, the background task to run the engine
|
||||
will be automatically started in the generate call.
|
||||
*args, *kwargs: Arguments for LLMEngine.
|
||||
"""
|
||||
|
||||
@ -240,17 +260,22 @@ class AsyncLLMEngine:
|
||||
engine_use_ray: bool,
|
||||
*args,
|
||||
log_requests: bool = True,
|
||||
start_engine_loop: bool = False,
|
||||
max_log_len: Optional[int] = None,
|
||||
start_engine_loop: bool = True,
|
||||
**kwargs) -> None:
|
||||
self.worker_use_ray = worker_use_ray
|
||||
self.engine_use_ray = engine_use_ray
|
||||
self.log_requests = log_requests
|
||||
self.max_log_len = max_log_len
|
||||
self.engine = self._init_engine(*args, **kwargs)
|
||||
|
||||
self.request_tracker: RequestTracker = RequestTracker()
|
||||
self.background_loop = None
|
||||
if start_engine_loop:
|
||||
self.start_background_loop()
|
||||
# We need to keep a reference to unshielded
|
||||
# task as well to prevent it from being garbage
|
||||
# collected
|
||||
self._background_loop_unshielded = None
|
||||
self.start_engine_loop = start_engine_loop
|
||||
self._request_tracker = RequestTracker()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
@ -261,11 +286,14 @@ class AsyncLLMEngine:
|
||||
"""Start the background loop."""
|
||||
if self.is_running:
|
||||
raise RuntimeError("Background loop is already running.")
|
||||
self.background_loop = asyncio.get_event_loop().create_task(
|
||||
self.run_engine_loop())
|
||||
self.background_loop.add_done_callback(
|
||||
self._request_tracker.init_event()
|
||||
|
||||
self._background_loop_unshielded = asyncio.get_event_loop(
|
||||
).create_task(self.run_engine_loop())
|
||||
self._background_loop_unshielded.add_done_callback(
|
||||
partial(_raise_exception_on_finish,
|
||||
request_tracker=self.request_tracker))
|
||||
request_tracker=self._request_tracker))
|
||||
self.background_loop = asyncio.shield(self._background_loop_unshielded)
|
||||
|
||||
def _init_engine(self, *args,
|
||||
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
|
||||
@ -277,11 +305,13 @@ class AsyncLLMEngine:
|
||||
engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
|
||||
return engine_class(*args, **kwargs)
|
||||
|
||||
async def engine_step(self):
|
||||
"""Kick the engine to process the waiting requests."""
|
||||
async def engine_step(self) -> bool:
|
||||
"""Kick the engine to process the waiting requests.
|
||||
|
||||
Returns True if there are in-progress requests."""
|
||||
|
||||
new_requests, finished_requests = (
|
||||
self.request_tracker.get_new_and_finished_requests())
|
||||
self._request_tracker.get_new_and_finished_requests())
|
||||
|
||||
for new_request in new_requests:
|
||||
# Add the request into the vLLM engine's waiting queue.
|
||||
@ -301,9 +331,11 @@ class AsyncLLMEngine:
|
||||
|
||||
# Put the outputs into the corresponding streams.
|
||||
for request_output in request_outputs:
|
||||
self.request_tracker.process_request_output(
|
||||
self._request_tracker.process_request_output(
|
||||
request_output, verbose=self.log_requests)
|
||||
|
||||
return len(request_outputs) > 0
|
||||
|
||||
async def _engine_abort(self, request_ids: Iterable[str]):
|
||||
if self.engine_use_ray:
|
||||
await self.engine.abort_request.remote(request_ids)
|
||||
@ -311,8 +343,12 @@ class AsyncLLMEngine:
|
||||
self.engine.abort_request(request_ids)
|
||||
|
||||
async def run_engine_loop(self):
|
||||
# Initialize the RequestTracker here so it uses the right event loop.
|
||||
has_requests_in_progress = False
|
||||
while True:
|
||||
await self.engine_step()
|
||||
if not has_requests_in_progress:
|
||||
await self._request_tracker.wait_for_new_requests()
|
||||
has_requests_in_progress = await self.engine_step()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def add_request(
|
||||
@ -324,19 +360,30 @@ class AsyncLLMEngine:
|
||||
arrival_time: Optional[float] = None,
|
||||
) -> AsyncStream:
|
||||
if self.log_requests:
|
||||
shortened_prompt = prompt
|
||||
shortened_token_ids = prompt_token_ids
|
||||
if self.max_log_len is not None:
|
||||
if shortened_prompt is not None:
|
||||
shortened_prompt = shortened_prompt[:self.max_log_len]
|
||||
if shortened_token_ids is not None:
|
||||
shortened_token_ids = shortened_token_ids[:self.
|
||||
max_log_len]
|
||||
logger.info(f"Received request {request_id}: "
|
||||
f"prompt: {prompt!r}, "
|
||||
f"prompt: {shortened_prompt!r}, "
|
||||
f"sampling params: {sampling_params}, "
|
||||
f"prompt token ids: {prompt_token_ids}.")
|
||||
f"prompt token ids: {shortened_token_ids}.")
|
||||
|
||||
if not self.is_running:
|
||||
raise AsyncEngineDeadError(
|
||||
"Background loop is not running. If it was running, "
|
||||
"inspect the output to find the stacktrace of the "
|
||||
"error that caused the background loop to stop "
|
||||
"(AsyncEngineDeadError).")
|
||||
if self.start_engine_loop:
|
||||
self.start_background_loop()
|
||||
else:
|
||||
raise AsyncEngineDeadError(
|
||||
"Background loop is not running. If it was running, "
|
||||
"inspect the output to find the stacktrace of the "
|
||||
"error that caused the background loop to stop "
|
||||
"(AsyncEngineDeadError).")
|
||||
|
||||
stream = self.request_tracker.add_request(
|
||||
stream = self._request_tracker.add_request(
|
||||
request_id,
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params,
|
||||
@ -381,8 +428,9 @@ class AsyncLLMEngine:
|
||||
|
||||
async for request_output in stream:
|
||||
yield request_output
|
||||
except Exception as e:
|
||||
# If there is an exception, abort the request.
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
# If there is an exception or coroutine is cancelled, abort the
|
||||
# request.
|
||||
self._abort(request_id)
|
||||
raise e
|
||||
|
||||
@ -413,8 +461,8 @@ class AsyncLLMEngine:
|
||||
Args:
|
||||
request_id: The unique id of the request.
|
||||
"""
|
||||
self.request_tracker.abort_request(request_id,
|
||||
verbose=self.log_requests)
|
||||
self._request_tracker.abort_request(request_id,
|
||||
verbose=self.log_requests)
|
||||
|
||||
async def get_model_config(self) -> ModelConfig:
|
||||
"""Get the model configuration of the vLLM engine."""
|
||||
@ -426,7 +474,7 @@ class AsyncLLMEngine:
|
||||
@classmethod
|
||||
def from_engine_args(cls,
|
||||
engine_args: AsyncEngineArgs,
|
||||
start_engine_loop: bool = False) -> "AsyncLLMEngine":
|
||||
start_engine_loop: bool = True) -> "AsyncLLMEngine":
|
||||
"""Creates an async LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
engine_configs = engine_args.create_engine_configs()
|
||||
@ -442,5 +490,6 @@ class AsyncLLMEngine:
|
||||
placement_group,
|
||||
log_requests=not engine_args.disable_log_requests,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
max_log_len=engine_args.max_log_len,
|
||||
start_engine_loop=start_engine_loop)
|
||||
return engine
|
||||
|
@ -54,8 +54,8 @@ class LLMEngine:
|
||||
scheduler_config: The configuration related to the request scheduler.
|
||||
distributed_init_method: The initialization method for distributed
|
||||
execution. See `torch.distributed.init_process_group` for details.
|
||||
stage_devices: The list of devices for each stage. Each stage is a list
|
||||
of (rank, node_resource, device) tuples.
|
||||
placement_group: Ray placement group for distributed execution.
|
||||
Required for distributed execution.
|
||||
log_stats: Whether to log statistics.
|
||||
"""
|
||||
|
||||
@ -74,16 +74,21 @@ class LLMEngine:
|
||||
f"model={model_config.model!r}, "
|
||||
f"tokenizer={model_config.tokenizer!r}, "
|
||||
f"tokenizer_mode={model_config.tokenizer_mode}, "
|
||||
f"revision={model_config.revision}, "
|
||||
f"trust_remote_code={model_config.trust_remote_code}, "
|
||||
f"dtype={model_config.dtype}, "
|
||||
f"max_seq_len={model_config.max_model_len}, "
|
||||
f"download_dir={model_config.download_dir!r}, "
|
||||
f"load_format={model_config.load_format}, "
|
||||
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
||||
f"quantization={model_config.quantization}, "
|
||||
f"seed={model_config.seed})")
|
||||
# TODO(woosuk): Print more configs in debug mode.
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
assert self.cache_config.sliding_window == getattr(
|
||||
self.model_config.hf_config, "sliding_window", None)
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.log_stats = log_stats
|
||||
@ -92,7 +97,8 @@ class LLMEngine:
|
||||
self.tokenizer = get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
tokenizer_mode=model_config.tokenizer_mode,
|
||||
trust_remote_code=model_config.trust_remote_code)
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
revision=model_config.revision)
|
||||
self.seq_counter = Counter()
|
||||
|
||||
# Create the parallel GPU workers.
|
||||
@ -153,7 +159,7 @@ class LLMEngine:
|
||||
placement_group=placement_group,
|
||||
placement_group_capture_child_tasks=True),
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorker).remote()
|
||||
)(RayWorker).remote(self.model_config.trust_remote_code)
|
||||
self.workers.append(worker)
|
||||
|
||||
# Initialize torch distributed process group for the workers.
|
||||
@ -291,14 +297,12 @@ class LLMEngine:
|
||||
def _schedule(
|
||||
self
|
||||
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
|
||||
Optional[List[RequestOutput]]]:
|
||||
List[RequestOutput]]:
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||
if scheduler_outputs.is_empty():
|
||||
return seq_group_metadata_list, scheduler_outputs, [
|
||||
RequestOutput.from_seq_group(seq_group)
|
||||
for seq_group in scheduler_outputs.ignored_seq_groups
|
||||
]
|
||||
return seq_group_metadata_list, scheduler_outputs, None
|
||||
return seq_group_metadata_list, scheduler_outputs, [
|
||||
RequestOutput.from_seq_group(seq_group)
|
||||
for seq_group in scheduler_outputs.ignored_seq_groups
|
||||
]
|
||||
|
||||
def _check_beam_search_early_stopping(
|
||||
self,
|
||||
@ -386,7 +390,7 @@ class LLMEngine:
|
||||
child_seqs.append((parent, parent))
|
||||
|
||||
for seq, _ in child_seqs:
|
||||
self._decode_sequence(seq)
|
||||
self._decode_sequence(seq, seq_group.sampling_params)
|
||||
self._check_stop(seq, seq_group.sampling_params)
|
||||
|
||||
# Non-beam search case
|
||||
@ -542,10 +546,9 @@ class LLMEngine:
|
||||
and updates the scheduler with the model outputs. Finally, it decodes
|
||||
the sequences and returns the newly generated results.
|
||||
"""
|
||||
(seq_group_metadata_list, scheduler_outputs,
|
||||
early_return) = self._schedule()
|
||||
if early_return is not None:
|
||||
return early_return
|
||||
seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
|
||||
if scheduler_outputs.is_empty():
|
||||
return ignored
|
||||
|
||||
# Execute the model.
|
||||
output = self._run_workers(
|
||||
@ -556,7 +559,7 @@ class LLMEngine:
|
||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||
)
|
||||
|
||||
return self._process_model_outputs(output, scheduler_outputs)
|
||||
return self._process_model_outputs(output, scheduler_outputs) + ignored
|
||||
|
||||
def _log_system_stats(
|
||||
self,
|
||||
@ -621,17 +624,25 @@ class LLMEngine:
|
||||
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
|
||||
self.last_logging_time = now
|
||||
|
||||
def _decode_sequence(self, seq: Sequence) -> None:
|
||||
def _decode_sequence(self, seq: Sequence,
|
||||
sampling_params: SamplingParams) -> None:
|
||||
"""Decodes the new token for a sequence."""
|
||||
new_token, new_output_text = detokenize_incrementally(
|
||||
self.tokenizer,
|
||||
seq.output_tokens,
|
||||
seq.get_last_token_id(),
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
if new_token is not None:
|
||||
seq.output_tokens.append(new_token)
|
||||
seq.output_text = new_output_text
|
||||
(new_tokens, new_output_text, prefix_offset,
|
||||
read_offset) = detokenize_incrementally(
|
||||
self.tokenizer,
|
||||
all_input_ids=seq.get_token_ids(),
|
||||
prev_tokens=seq.tokens,
|
||||
prefix_offset=seq.prefix_offset,
|
||||
read_offset=seq.read_offset,
|
||||
skip_special_tokens=sampling_params.skip_special_tokens,
|
||||
)
|
||||
if seq.tokens is None:
|
||||
seq.tokens = new_tokens
|
||||
else:
|
||||
seq.tokens.extend(new_tokens)
|
||||
seq.prefix_offset = prefix_offset
|
||||
seq.read_offset = read_offset
|
||||
seq.output_text += new_output_text
|
||||
|
||||
def _check_stop(self, seq: Sequence,
|
||||
sampling_params: SamplingParams) -> None:
|
||||
@ -643,6 +654,9 @@ class LLMEngine:
|
||||
seq.output_text = seq.output_text[:-len(stop_str)]
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
if seq.get_last_token_id() in sampling_params.stop_token_ids:
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
|
||||
# Check if the sequence has reached max_model_len.
|
||||
if seq.get_len() > self.scheduler_config.max_model_len:
|
||||
|
@ -2,6 +2,9 @@ import socket
|
||||
from typing import Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
import ray
|
||||
@ -11,7 +14,11 @@ try:
|
||||
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
|
||||
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, init_cached_hf_modules=False) -> None:
|
||||
if init_cached_hf_modules:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from transformers.dynamic_module_utils import init_hf_modules
|
||||
init_hf_modules()
|
||||
self.worker = None
|
||||
|
||||
def init_worker(self, worker_init_fn):
|
||||
@ -24,7 +31,10 @@ try:
|
||||
executor = getattr(self, method)
|
||||
return executor(*args, **kwargs)
|
||||
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
logger.warning(f"Failed to import Ray with {e!r}. "
|
||||
"For distributed inference, please install Ray with "
|
||||
"`pip install ray pandas pyarrow`.")
|
||||
ray = None
|
||||
TorchDistributedWorker = None
|
||||
RayWorker = None # pylint: disable=invalid-name
|
||||
@ -53,11 +63,10 @@ def initialize_cluster(
|
||||
the default Ray cluster address.
|
||||
|
||||
Returns:
|
||||
A tuple of (`distributed_init_method`, `all_stage_devices`). The
|
||||
A tuple of (`distributed_init_method`, `placement_group`). The
|
||||
`distributed_init_method` is the address for initializing the
|
||||
distributed backend. `all_stage_devices` includes device IDs for
|
||||
each worker in each pipeline stage. Each device ID is a tuple of
|
||||
(rank, node resource, device id).
|
||||
distributed backend. `placement_group` includes the specification
|
||||
of the resources for each distributed worker.
|
||||
"""
|
||||
if parallel_config.worker_use_ray or engine_use_ray:
|
||||
if ray is None:
|
||||
|
@ -2,7 +2,7 @@ import argparse
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import BackgroundTasks, FastAPI, Request
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
import uvicorn
|
||||
|
||||
@ -32,9 +32,6 @@ async def generate(request: Request) -> Response:
|
||||
sampling_params = SamplingParams(**request_dict)
|
||||
request_id = random_uuid()
|
||||
|
||||
if not engine.is_running:
|
||||
engine.start_background_loop()
|
||||
|
||||
results_generator = engine.generate(prompt, sampling_params, request_id)
|
||||
|
||||
# Streaming case
|
||||
@ -47,14 +44,8 @@ async def generate(request: Request) -> Response:
|
||||
ret = {"text": text_outputs}
|
||||
yield (json.dumps(ret) + "\0").encode("utf-8")
|
||||
|
||||
async def abort_request() -> None:
|
||||
await engine.abort(request_id)
|
||||
|
||||
if stream:
|
||||
background_tasks = BackgroundTasks()
|
||||
# Abort the request if the client disconnects.
|
||||
background_tasks.add_task(abort_request)
|
||||
return StreamingResponse(stream_results(), background=background_tasks)
|
||||
return StreamingResponse(stream_results())
|
||||
|
||||
# Non-streaming case
|
||||
final_output = None
|
||||
@ -80,8 +71,7 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args,
|
||||
start_engine_loop=False)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
|
||||
uvicorn.run(app,
|
||||
host=args.host,
|
||||
|
@ -37,7 +37,22 @@ class LLM:
|
||||
the `torch_dtype` attribute specified in the model config file.
|
||||
However, if the `torch_dtype` in the config is `float32`, we will
|
||||
use `float16` instead.
|
||||
quantization: The method used to quantize the model weights. Currently,
|
||||
we support "awq". If None, we assume the model weights are not
|
||||
quantized and use `dtype` to determine the data type of the weights.
|
||||
revision: The specific model version to use. It can be a branch name,
|
||||
a tag name, or a commit id.
|
||||
seed: The seed to initialize the random number generator for sampling.
|
||||
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
|
||||
reserve for the model weights, activations, and KV cache. Higher
|
||||
values will increase the KV cache size and thus improve the model's
|
||||
throughput. However, if the value is too high, it may cause out-of-
|
||||
memory (OOM) errors.
|
||||
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
|
||||
This can be used for temporarily storing the states of the requests
|
||||
when their `best_of` sampling parameters are larger than 1. If all
|
||||
requests will have `best_of=1`, you can safely set this to 0.
|
||||
Otherwise, too small values may cause out-of-memory (OOM) errors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -48,7 +63,11 @@ class LLM:
|
||||
trust_remote_code: bool = False,
|
||||
tensor_parallel_size: int = 1,
|
||||
dtype: str = "auto",
|
||||
quantization: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
seed: int = 0,
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
swap_space: int = 4,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if "disable_log_stats" not in kwargs:
|
||||
@ -60,7 +79,11 @@ class LLM:
|
||||
trust_remote_code=trust_remote_code,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
dtype=dtype,
|
||||
quantization=quantization,
|
||||
revision=revision,
|
||||
seed=seed,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
swap_space=swap_space,
|
||||
**kwargs,
|
||||
)
|
||||
self.llm_engine = LLMEngine.from_engine_args(engine_args)
|
||||
|
@ -10,7 +10,7 @@ from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import fastapi
|
||||
import uvicorn
|
||||
from fastapi import BackgroundTasks, Request
|
||||
from fastapi import Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
@ -130,6 +130,8 @@ async def check_length(
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
token_num = len(input_ids)
|
||||
|
||||
if request.max_tokens is None:
|
||||
request.max_tokens = max_model_len - token_num
|
||||
if token_num + request.max_tokens > max_model_len:
|
||||
return input_ids, create_error_response(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
@ -192,14 +194,11 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
"""
|
||||
logger.info(f"Received chat completion request: {request}")
|
||||
|
||||
if not engine.is_running:
|
||||
engine.start_background_loop()
|
||||
|
||||
error_check_ret = await check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
if request.logit_bias is not None:
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||
# TODO: support logit_bias in vLLM engine.
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"logit_bias is not currently supported")
|
||||
@ -220,11 +219,13 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
stop=request.stop,
|
||||
stop_token_ids=request.stop_token_ids,
|
||||
max_tokens=request.max_tokens,
|
||||
best_of=request.best_of,
|
||||
top_k=request.top_k,
|
||||
ignore_eos=request.ignore_eos,
|
||||
use_beam_search=request.use_beam_search,
|
||||
skip_special_tokens=request.skip_special_tokens,
|
||||
)
|
||||
except ValueError as e:
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||
@ -232,9 +233,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
result_generator = engine.generate(prompt, sampling_params, request_id,
|
||||
token_ids)
|
||||
|
||||
async def abort_request() -> None:
|
||||
await engine.abort(request_id)
|
||||
|
||||
def create_stream_response_json(
|
||||
index: int,
|
||||
text: str,
|
||||
@ -294,19 +292,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
|
||||
# Streaming response
|
||||
if request.stream:
|
||||
background_tasks = BackgroundTasks()
|
||||
# Abort the request if the client disconnects.
|
||||
background_tasks.add_task(abort_request)
|
||||
return StreamingResponse(completion_stream_generator(),
|
||||
media_type="text/event-stream",
|
||||
background=background_tasks)
|
||||
media_type="text/event-stream")
|
||||
|
||||
# Non-streaming response
|
||||
final_res: RequestOutput = None
|
||||
async for res in result_generator:
|
||||
if await raw_request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await abort_request()
|
||||
await engine.abort(request_id)
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"Client disconnected")
|
||||
final_res = res
|
||||
@ -367,9 +361,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
"""
|
||||
logger.info(f"Received completion request: {request}")
|
||||
|
||||
if not engine.is_running:
|
||||
engine.start_background_loop()
|
||||
|
||||
error_check_ret = await check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
@ -385,7 +376,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"suffix is not currently supported")
|
||||
|
||||
if request.logit_bias is not None:
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||
# TODO: support logit_bias in vLLM engine.
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"logit_bias is not currently supported")
|
||||
@ -431,10 +422,12 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
stop=request.stop,
|
||||
stop_token_ids=request.stop_token_ids,
|
||||
ignore_eos=request.ignore_eos,
|
||||
max_tokens=request.max_tokens,
|
||||
logprobs=request.logprobs,
|
||||
use_beam_search=request.use_beam_search,
|
||||
skip_special_tokens=request.skip_special_tokens,
|
||||
)
|
||||
except ValueError as e:
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||
@ -454,9 +447,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
and (request.best_of is None or request.n == request.best_of)
|
||||
and not request.use_beam_search)
|
||||
|
||||
async def abort_request() -> None:
|
||||
await engine.abort(request_id)
|
||||
|
||||
def create_stream_response_json(
|
||||
index: int,
|
||||
text: str,
|
||||
@ -516,19 +506,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
|
||||
# Streaming response
|
||||
if stream:
|
||||
background_tasks = BackgroundTasks()
|
||||
# Abort the request if the client disconnects.
|
||||
background_tasks.add_task(abort_request)
|
||||
return StreamingResponse(completion_stream_generator(),
|
||||
media_type="text/event-stream",
|
||||
background=background_tasks)
|
||||
media_type="text/event-stream")
|
||||
|
||||
# Non-streaming response
|
||||
final_res: RequestOutput = None
|
||||
async for res in result_generator:
|
||||
if await raw_request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await abort_request()
|
||||
await engine.abort(request_id)
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"Client disconnected")
|
||||
final_res = res
|
||||
@ -627,10 +613,9 @@ if __name__ == "__main__":
|
||||
served_model = args.model
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args,
|
||||
start_engine_loop=False)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
engine_model_config = asyncio.run(engine.get_model_config())
|
||||
max_model_len = engine_model_config.get_max_model_len()
|
||||
max_model_len = engine_model_config.max_model_len
|
||||
|
||||
# A separate tokenizer to map token IDs to strings.
|
||||
tokenizer = get_tokenizer(engine_args.tokenizer,
|
||||
|
@ -58,7 +58,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
temperature: Optional[float] = 0.7
|
||||
top_p: Optional[float] = 1.0
|
||||
n: Optional[int] = 1
|
||||
max_tokens: Optional[int] = 16
|
||||
max_tokens: Optional[int] = None
|
||||
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
||||
stream: Optional[bool] = False
|
||||
presence_penalty: Optional[float] = 0.0
|
||||
@ -70,6 +70,8 @@ class ChatCompletionRequest(BaseModel):
|
||||
top_k: Optional[int] = -1
|
||||
ignore_eos: Optional[bool] = False
|
||||
use_beam_search: Optional[bool] = False
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
skip_special_tokens: Optional[bool] = True
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
@ -94,6 +96,8 @@ class CompletionRequest(BaseModel):
|
||||
top_k: Optional[int] = -1
|
||||
ignore_eos: Optional[bool] = False
|
||||
use_beam_search: Optional[bool] = False
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
skip_special_tokens: Optional[bool] = True
|
||||
|
||||
|
||||
class LogProbs(BaseModel):
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from xformers.ops import AttentionBias
|
||||
@ -29,6 +29,7 @@ class InputMetadata:
|
||||
context_lens: torch.Tensor,
|
||||
max_context_len: int,
|
||||
block_tables: torch.Tensor,
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> None:
|
||||
self.seq_groups = seq_groups
|
||||
self.seq_data = seq_data
|
||||
@ -38,6 +39,24 @@ class InputMetadata:
|
||||
self.max_context_len = max_context_len
|
||||
self.block_tables = block_tables
|
||||
|
||||
self.to_cache = None
|
||||
if sliding_window is not None:
|
||||
# We need to keep the positions of sliding windows within
|
||||
# the key / value tables, this is helpful to know which
|
||||
# elements we need to cache and where
|
||||
to_cache, start_idx = [], 0
|
||||
for prompt_len in self.prompt_lens:
|
||||
to_cache.extend(
|
||||
range(
|
||||
start_idx + max(0, prompt_len - sliding_window),
|
||||
start_idx + prompt_len,
|
||||
))
|
||||
start_idx += prompt_len
|
||||
to_cache.extend(range(start_idx, slot_mapping.shape[0]))
|
||||
self.to_cache = torch.tensor(to_cache,
|
||||
dtype=torch.int32,
|
||||
device=self.slot_mapping.device)
|
||||
|
||||
self.num_prompts = len(prompt_lens)
|
||||
self.num_prompt_tokens = sum(prompt_lens)
|
||||
self.num_generation_tokens = context_lens.shape[0]
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Multi-head attention."""
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -9,8 +9,10 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
|
||||
|
||||
from vllm import attention_ops
|
||||
from vllm import cache_ops
|
||||
from vllm import pos_encoding_ops
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
|
||||
RotaryEmbedding)
|
||||
|
||||
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||
|
||||
@ -56,12 +58,14 @@ class PagedAttention(nn.Module):
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: Optional[int] = None) -> None:
|
||||
num_kv_heads: Optional[int] = None,
|
||||
sliding_window: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
@ -73,12 +77,19 @@ class PagedAttention(nn.Module):
|
||||
raise ValueError(f"head_size ({self.head_size}) is not supported. "
|
||||
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
|
||||
|
||||
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
|
||||
def set_attn_bias(
|
||||
self,
|
||||
input_metadata: InputMetadata,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
del dtype # Unused.
|
||||
if input_metadata.attn_bias:
|
||||
# Already set by a previous layer.
|
||||
return
|
||||
prompt_lens = input_metadata.prompt_lens
|
||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
|
||||
if self.sliding_window is not None:
|
||||
attn_bias = attn_bias.make_local_attention(self.sliding_window)
|
||||
input_metadata.attn_bias.append(attn_bias)
|
||||
|
||||
def multi_query_kv_attention(
|
||||
@ -196,7 +207,7 @@ class PagedAttention(nn.Module):
|
||||
if num_prompt_tokens > 0:
|
||||
# Prompt run.
|
||||
assert input_metadata.num_generation_tokens == 0
|
||||
self.set_attn_bias(input_metadata)
|
||||
self.set_attn_bias(input_metadata, dtype=query.dtype)
|
||||
self.multi_query_kv_attention(
|
||||
output[:num_prompt_tokens],
|
||||
query[:num_prompt_tokens],
|
||||
@ -216,12 +227,20 @@ class PagedAttention(nn.Module):
|
||||
if (num_valid_tokens > 0 and key_cache is not None
|
||||
and value_cache is not None):
|
||||
# The stride is 3 because the key and value are sliced from qkv.
|
||||
key_to_cache = key[:num_valid_tokens]
|
||||
value_to_cache = value[:num_valid_tokens]
|
||||
slot_mapping = input_metadata.slot_mapping
|
||||
if input_metadata.to_cache is not None:
|
||||
key_to_cache = key_to_cache[input_metadata.to_cache]
|
||||
value_to_cache = value_to_cache[input_metadata.to_cache]
|
||||
slot_mapping = slot_mapping[input_metadata.to_cache]
|
||||
|
||||
cache_ops.reshape_and_cache(
|
||||
key[:num_valid_tokens],
|
||||
value[:num_valid_tokens],
|
||||
key_to_cache,
|
||||
value_to_cache,
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata.slot_mapping,
|
||||
slot_mapping,
|
||||
)
|
||||
|
||||
if input_metadata.num_generation_tokens > 0:
|
||||
@ -242,7 +261,7 @@ class PagedAttention(nn.Module):
|
||||
|
||||
|
||||
class PagedAttentionWithRoPE(PagedAttention):
|
||||
"""PagedAttention with rotary embedding."""
|
||||
"""PagedAttention with rotary positional embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -254,25 +273,31 @@ class PagedAttentionWithRoPE(PagedAttention):
|
||||
base: int = 10000,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
is_neox_style: bool = True,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads)
|
||||
self.is_neox_style = is_neox_style
|
||||
|
||||
# Create the cos and sin cache.
|
||||
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
|
||||
t = torch.arange(max_position).float()
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
|
||||
# FIXME(woosuk): This assumes that we configure the default dtype when
|
||||
# initializing the model.
|
||||
# TODO(woosuk): Make it more robust.
|
||||
torch_dtype = torch.get_default_dtype()
|
||||
cache = cache.to(torch_dtype)
|
||||
# Embedding size: [max_position, rotary_dim]
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
super().__init__(num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
sliding_window=sliding_window)
|
||||
if rope_scaling is None:
|
||||
self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
|
||||
max_position, base,
|
||||
is_neox_style)
|
||||
else:
|
||||
scaling_type = rope_scaling["type"]
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
if scaling_type == "linear":
|
||||
self.rotary_emb = LinearScalingRotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
scaling_factor)
|
||||
elif scaling_type == "dynamic":
|
||||
self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
scaling_factor)
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -289,7 +314,7 @@ class PagedAttentionWithRoPE(PagedAttention):
|
||||
|
||||
Args:
|
||||
positions: shape = [num_tokens]
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
|
||||
@ -305,14 +330,7 @@ class PagedAttentionWithRoPE(PagedAttention):
|
||||
|
||||
# Apply rotary embedding to the query and key before passing them
|
||||
# to the attention op.
|
||||
pos_encoding_ops.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
query, key = self.rotary_emb(positions, query, key)
|
||||
return super().forward(
|
||||
query,
|
||||
key,
|
||||
@ -339,13 +357,14 @@ class PagedAttentionWithALiBi(PagedAttention):
|
||||
slopes = torch.tensor(slopes, dtype=torch.float32)
|
||||
self.register_buffer("alibi_slopes", slopes, persistent=False)
|
||||
|
||||
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
|
||||
def set_attn_bias(self, input_metadata: InputMetadata,
|
||||
dtype: torch.dtype) -> None:
|
||||
if input_metadata.attn_bias:
|
||||
# Already set by a previous layer.
|
||||
return
|
||||
# Generates ALiBi mask for each prompt.
|
||||
for prompt_len in input_metadata.prompt_lens:
|
||||
bias = torch.arange(prompt_len)
|
||||
bias = torch.arange(prompt_len, dtype=dtype)
|
||||
# Note(zhuohan): HF uses
|
||||
# `bias = bias[None, :].repeat(prompt_len, 1)`
|
||||
# here. We find that both biases give the same results, but
|
||||
@ -363,6 +382,7 @@ class PagedAttentionWithALiBi(PagedAttention):
|
||||
prompt_len,
|
||||
padded_len,
|
||||
device=self.alibi_slopes.device,
|
||||
dtype=dtype,
|
||||
)[:, :, :, :prompt_len].copy_(bias)
|
||||
bias.mul_(self.alibi_slopes[:, None, None])
|
||||
attn_bias = LowerTriangularMaskWithTensorBias(bias)
|
||||
|
37
vllm/model_executor/layers/quantized_linear/__init__.py
Normal file
37
vllm/model_executor/layers/quantized_linear/__init__.py
Normal file
@ -0,0 +1,37 @@
|
||||
from vllm.model_executor.layers.quantized_linear.awq import (
|
||||
AWQColumnParallelLinear, AWQRowParallelLinear)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
ColumnParallelLinear, RowParallelLinear)
|
||||
|
||||
_QUANTIZED_LINEAR_REGISTRY = {
|
||||
"awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
|
||||
}
|
||||
|
||||
|
||||
class ParallelLinear:
|
||||
|
||||
@classmethod
|
||||
def column(cls, *args, **kwargs) -> ColumnParallelLinear:
|
||||
quant_config = kwargs.get("quant_config", None)
|
||||
if quant_config is None:
|
||||
return ColumnParallelLinear(*args, **kwargs)
|
||||
|
||||
name = quant_config.get_name()
|
||||
if name not in _QUANTIZED_LINEAR_REGISTRY:
|
||||
raise ValueError(f"No quantized linear is found for {name}")
|
||||
|
||||
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
|
||||
return quant_linear_cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def row(cls, *args, **kwargs) -> RowParallelLinear:
|
||||
quant_config = kwargs.get("quant_config", None)
|
||||
if quant_config is None:
|
||||
return RowParallelLinear(*args, **kwargs)
|
||||
|
||||
name = quant_config.get_name()
|
||||
if name not in _QUANTIZED_LINEAR_REGISTRY:
|
||||
raise ValueError(f"No quantized linear is found for {name}")
|
||||
|
||||
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
|
||||
return quant_linear_cls(*args, **kwargs)
|
102
vllm/model_executor/layers/quantized_linear/awq.py
Normal file
102
vllm/model_executor/layers/quantized_linear/awq.py
Normal file
@ -0,0 +1,102 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import quantization_ops
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel.layers import (
|
||||
ColumnParallelLinear, RowParallelLinear)
|
||||
|
||||
|
||||
class AWQColumnParallelLinear(ColumnParallelLinear):
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
assert self.input_size % self.quant_config.weight_bits == 0
|
||||
assert (self.output_size_per_partition %
|
||||
self.quant_config.pack_factor == 0)
|
||||
self.qweight = Parameter(
|
||||
torch.empty(
|
||||
self.input_size,
|
||||
self.output_size_per_partition //
|
||||
self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.qzeros = Parameter(
|
||||
torch.empty(
|
||||
self.input_size // self.quant_config.group_size,
|
||||
self.output_size_per_partition //
|
||||
self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.scales = Parameter(
|
||||
torch.empty(
|
||||
self.input_size // self.quant_config.group_size,
|
||||
self.output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
pack_factor = self.quant_config.pack_factor
|
||||
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
|
||||
self.qzeros, pack_factor)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.reshape(out_shape)
|
||||
|
||||
|
||||
class AWQRowParallelLinear(RowParallelLinear):
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
assert (self.input_size_per_partition %
|
||||
self.quant_config.weight_bits == 0)
|
||||
assert self.output_size % self.quant_config.pack_factor == 0
|
||||
self.qweight = Parameter(
|
||||
torch.empty(
|
||||
self.input_size_per_partition,
|
||||
self.output_size // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.qzeros = Parameter(
|
||||
torch.empty(
|
||||
self.input_size_per_partition // self.quant_config.group_size,
|
||||
self.output_size // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.scales = Parameter(
|
||||
torch.empty(
|
||||
self.input_size_per_partition // self.quant_config.group_size,
|
||||
self.output_size,
|
||||
device="cuda",
|
||||
dtype=dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||
pack_factor = self.quant_config.pack_factor
|
||||
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
|
||||
self.qzeros, pack_factor)
|
||||
return out.reshape(out_shape)
|
169
vllm/model_executor/layers/rotary_embedding.py
Normal file
169
vllm/model_executor/layers/rotary_embedding.py
Normal file
@ -0,0 +1,169 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Rotary Positional Embeddings."""
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm import pos_encoding_ops
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
"""Original rotary positional embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
self.rotary_dim = rotary_dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.is_neox_style = is_neox_style
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
cache = cache.to(torch.get_default_dtype())
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
"""Compute the inverse frequency."""
|
||||
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
|
||||
# However, we use `torch.arange(..., dtype=torch.float)` instead to
|
||||
# avoid numerical issues with large base values (e.g., 10000000).
|
||||
# This may cause a slight numerical difference between the HF
|
||||
# implementation and ours.
|
||||
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
||||
# use CPU to compute the cache and then move it to GPU. However, we
|
||||
# create the cache on GPU for faster initialization. This may cause
|
||||
# a slight numerical difference between the HF implementation and ours.
|
||||
inv_freq = 1.0 / (base**(torch.arange(
|
||||
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
|
||||
self.rotary_dim))
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
"""Compute the cos and sin cache."""
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
t = torch.arange(self.max_position_embeddings,
|
||||
dtype=torch.float,
|
||||
device="cuda")
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# pos_encoding_ops.rotary_embedding() is an in-place operation that
|
||||
# updates the query and key tensors.
|
||||
pos_encoding_ops.rotary_embedding(positions, query, key,
|
||||
self.head_size, self.cos_sin_cache,
|
||||
self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
|
||||
class LinearScalingRotaryEmbedding(RotaryEmbedding):
|
||||
"""RotaryEmbedding extended with linear scaling.
|
||||
|
||||
Credits to the Reddit user /u/kaiokendev
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style)
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
# NOTE(woosuk): self.max_position_embeddings is the original
|
||||
# maximum length before applying the rope scaling.
|
||||
# Thus, the maximum length after applying the rope scaling is
|
||||
# self.max_position_embeddings * self.scaling_factor.
|
||||
max_len = self.max_position_embeddings * self.scaling_factor
|
||||
t = torch.arange(max_len, dtype=torch.float, device="cuda")
|
||||
t = t / self.scaling_factor
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
|
||||
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
|
||||
"""RotaryEmbedding extended with Dynamic NTK scaling.
|
||||
|
||||
Credits to the Reddit users /u/bloc97 and /u/emozilla
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
scaling_factor: float,
|
||||
) -> None:
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
|
||||
is_neox_style)
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
# NOTE(woosuk): self.max_position_embeddings is the original
|
||||
# maximum length before applying the rope scaling.
|
||||
# Thus, the maximum length after applying the rope scaling is
|
||||
# self.max_position_embeddings * self.scaling_factor.
|
||||
max_len = self.max_position_embeddings * self.scaling_factor
|
||||
base = self.base * (
|
||||
(self.scaling_factor * max_len / self.max_position_embeddings) -
|
||||
(self.scaling_factor - 1))**(self.rotary_dim /
|
||||
(self.rotary_dim - 2))
|
||||
inv_freq = self._compute_inv_freq(base)
|
||||
t = torch.arange(max_len, dtype=torch.float, device="cuda")
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
@ -1,15 +1,14 @@
|
||||
"""A layer that samples the next tokens from the model's outputs."""
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
gather_from_tensor_model_parallel_region)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import SamplerOutput, SequenceOutputs
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
@ -44,12 +43,8 @@ class Sampler(nn.Module):
|
||||
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
|
||||
|
||||
# Get the logits for the next tokens.
|
||||
logits = torch.matmul(hidden_states, embedding.t())
|
||||
if embedding_bias is not None:
|
||||
logits += embedding_bias
|
||||
logits = gather_from_tensor_model_parallel_region(logits)
|
||||
# Remove paddings in vocab (if any).
|
||||
logits = logits[:, :self.vocab_size]
|
||||
logits = _get_logits(hidden_states, embedding, embedding_bias,
|
||||
self.vocab_size)
|
||||
|
||||
# Apply presence and frequency penalties.
|
||||
output_tokens = _get_output_tokens(input_metadata)
|
||||
@ -59,7 +54,7 @@ class Sampler(nn.Module):
|
||||
assert len(presence_penalties) == logits.shape[0]
|
||||
assert len(frequency_penalties) == logits.shape[0]
|
||||
logits = _apply_penalties(logits, output_tokens, presence_penalties,
|
||||
frequency_penalties, self.vocab_size)
|
||||
frequency_penalties)
|
||||
|
||||
# Apply temperature scaling.
|
||||
temperatures = _get_temperatures(input_metadata)
|
||||
@ -82,26 +77,55 @@ class Sampler(nn.Module):
|
||||
# We use float32 for probabilities and log probabilities.
|
||||
# Compute the probabilities.
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||
# Compute the log probabilities (before applying top-p and top-k).
|
||||
logprobs = torch.log(probs)
|
||||
# Compute the log probabilities.
|
||||
# Use log_softmax to ensure numerical stability.
|
||||
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
||||
|
||||
# Sample the next tokens.
|
||||
return _sample(probs, logprobs, input_metadata)
|
||||
|
||||
|
||||
def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
|
||||
embedding_bias: Optional[torch.Tensor],
|
||||
vocab_size: int) -> torch.Tensor:
|
||||
# Get the logits for the next tokens.
|
||||
logits = torch.matmul(hidden_states, embedding.t())
|
||||
if embedding_bias is not None:
|
||||
logits += embedding_bias
|
||||
logits = gather_from_tensor_model_parallel_region(logits)
|
||||
# Remove paddings in vocab (if any).
|
||||
logits = logits[:, :vocab_size]
|
||||
return logits
|
||||
|
||||
|
||||
def _prune_hidden_states(
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
last_token_indices = {t: [] for t in SamplingType}
|
||||
start_idx = 0
|
||||
last_token_indicies: List[int] = []
|
||||
for prompt_len in input_metadata.prompt_lens:
|
||||
last_token_indicies.append(start_idx + prompt_len - 1)
|
||||
start_idx += prompt_len
|
||||
last_token_indicies.extend(
|
||||
range(start_idx, start_idx + input_metadata.num_generation_tokens))
|
||||
return hidden_states.index_select(
|
||||
0, torch.tensor(last_token_indicies, device=hidden_states.device))
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_ids, sampling_params = seq_group
|
||||
sampling_type = sampling_params.sampling_type
|
||||
if i < input_metadata.num_prompts:
|
||||
assert len(seq_ids) == 1, "Prompt input should have only one seq."
|
||||
prompt_len = input_metadata.prompt_lens[i]
|
||||
last_token_indices[sampling_type].append(start_idx + prompt_len -
|
||||
1)
|
||||
start_idx += prompt_len
|
||||
else:
|
||||
num_seqs = len(seq_ids)
|
||||
last_token_indices[sampling_type].extend(
|
||||
range(start_idx, start_idx + num_seqs))
|
||||
start_idx += num_seqs
|
||||
|
||||
all_last_token_indices = []
|
||||
for sampling_type in SamplingType:
|
||||
all_last_token_indices.extend(last_token_indices[sampling_type])
|
||||
all_last_token_indices = torch.tensor(all_last_token_indices,
|
||||
dtype=torch.long,
|
||||
device=hidden_states.device)
|
||||
return hidden_states.index_select(0, all_last_token_indices)
|
||||
|
||||
|
||||
def _get_penalties(
|
||||
@ -109,37 +133,22 @@ def _get_penalties(
|
||||
# Collect the presence and frequency penalties.
|
||||
presence_penalties: List[float] = []
|
||||
frequency_penalties: List[float] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
for seq_group in input_metadata.seq_groups:
|
||||
seq_ids, sampling_params = seq_group
|
||||
p = sampling_params.presence_penalty
|
||||
f = sampling_params.frequency_penalty
|
||||
if i < input_metadata.num_prompts:
|
||||
# A prompt input.
|
||||
presence_penalties.append(p)
|
||||
frequency_penalties.append(f)
|
||||
else:
|
||||
# A generation token.
|
||||
presence_penalties += [p] * len(seq_ids)
|
||||
frequency_penalties += [f] * len(seq_ids)
|
||||
presence_penalties += [p] * len(seq_ids)
|
||||
frequency_penalties += [f] * len(seq_ids)
|
||||
return presence_penalties, frequency_penalties
|
||||
|
||||
|
||||
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
|
||||
output_tokens: List[List[int]] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
for seq_group in input_metadata.seq_groups:
|
||||
seq_ids, _ = seq_group
|
||||
if i < input_metadata.num_prompts:
|
||||
# A prompt input.
|
||||
# NOTE: While the prompt input usually has no output tokens,
|
||||
# it may have output tokens in the case of recomputation.
|
||||
seq_id = seq_ids[0]
|
||||
for seq_id in seq_ids:
|
||||
seq_data = input_metadata.seq_data[seq_id]
|
||||
output_tokens.append(seq_data.output_token_ids)
|
||||
else:
|
||||
# A generation token.
|
||||
for seq_id in seq_ids:
|
||||
seq_data = input_metadata.seq_data[seq_id]
|
||||
output_tokens.append(seq_data.output_token_ids)
|
||||
return output_tokens
|
||||
|
||||
|
||||
@ -148,11 +157,8 @@ def _apply_penalties(
|
||||
output_tokens: List[List[int]],
|
||||
presence_penalties: List[float],
|
||||
frequency_penalties: List[float],
|
||||
vocab_size: int,
|
||||
) -> torch.Tensor:
|
||||
num_seqs = logits.shape[0]
|
||||
# Collect the indices of sequences that have non-zero penalties.
|
||||
indices = []
|
||||
num_seqs, vocab_size = logits.shape
|
||||
for i in range(num_seqs):
|
||||
if not output_tokens[i]:
|
||||
continue
|
||||
@ -160,40 +166,47 @@ def _apply_penalties(
|
||||
f = frequency_penalties[i]
|
||||
if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
|
||||
continue
|
||||
indices.append(i)
|
||||
|
||||
# Return early if all sequences have zero penalties.
|
||||
if not indices:
|
||||
break
|
||||
else:
|
||||
# Return early if all sequences have zero penalties.
|
||||
return logits
|
||||
|
||||
bin_counts = []
|
||||
for i in indices:
|
||||
bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
|
||||
bin_counts = np.stack(bin_counts, axis=0)
|
||||
bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
|
||||
device=logits.device)
|
||||
max_output_len = max(len(tokens) for tokens in output_tokens)
|
||||
padded_output_tokens = [
|
||||
tokens + [vocab_size] * (max_output_len - len(tokens))
|
||||
for tokens in output_tokens
|
||||
]
|
||||
output_tokens_tensor = torch.tensor(padded_output_tokens,
|
||||
dtype=torch.long,
|
||||
device=logits.device)
|
||||
|
||||
# Compute the bin counts for the output tokens.
|
||||
# vocab_size + 1 for padding.
|
||||
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
|
||||
dtype=torch.long,
|
||||
device=logits.device)
|
||||
bin_counts.scatter_add_(1, output_tokens_tensor,
|
||||
torch.ones_like(output_tokens_tensor))
|
||||
bin_counts = bin_counts[:, :vocab_size] # Remove the padding bin.
|
||||
|
||||
frequency_penalties = [frequency_penalties[i] for i in indices]
|
||||
frequency_penalties = torch.tensor(frequency_penalties,
|
||||
dtype=logits.dtype,
|
||||
device=logits.device)
|
||||
presence_penalties = [presence_penalties[i] for i in indices]
|
||||
presence_penalties = torch.tensor(presence_penalties,
|
||||
dtype=logits.dtype,
|
||||
device=logits.device)
|
||||
|
||||
# We follow the definition in OpenAI API.
|
||||
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
|
||||
logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
|
||||
presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
|
||||
logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
|
||||
logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
|
||||
logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
|
||||
return logits
|
||||
|
||||
|
||||
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
|
||||
# Collect the temperatures for the logits.
|
||||
temperatures: List[float] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
for seq_group in input_metadata.seq_groups:
|
||||
seq_ids, sampling_params = seq_group
|
||||
temperature = sampling_params.temperature
|
||||
if temperature < _SAMPLING_EPS:
|
||||
@ -201,13 +214,7 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
|
||||
# (i.e., greedy sampling or beam search).
|
||||
# Set the temperature to 1 to avoid division by zero.
|
||||
temperature = 1.0
|
||||
|
||||
if i < input_metadata.num_prompts:
|
||||
# A prompt input.
|
||||
temperatures.append(temperature)
|
||||
else:
|
||||
# A generation token.
|
||||
temperatures += [temperature] * len(seq_ids)
|
||||
temperatures += [temperature] * len(seq_ids)
|
||||
return temperatures
|
||||
|
||||
|
||||
@ -217,21 +224,15 @@ def _get_top_p_top_k(
|
||||
) -> Tuple[List[float], List[int]]:
|
||||
top_ps: List[float] = []
|
||||
top_ks: List[int] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
for seq_group in input_metadata.seq_groups:
|
||||
seq_ids, sampling_params = seq_group
|
||||
top_p = sampling_params.top_p
|
||||
# k should not be greater than the vocab size.
|
||||
top_k = min(sampling_params.top_k, vocab_size)
|
||||
# k=-1 means no truncation.
|
||||
top_k = vocab_size if top_k == -1 else top_k
|
||||
if i < input_metadata.num_prompts:
|
||||
# A prompt input.
|
||||
top_ps.append(top_p)
|
||||
top_ks.append(top_k)
|
||||
else:
|
||||
# A generation token.
|
||||
top_ps += [top_p] * len(seq_ids)
|
||||
top_ks += [top_k] * len(seq_ids)
|
||||
top_ps += [top_p] * len(seq_ids)
|
||||
top_ks += [top_k] * len(seq_ids)
|
||||
return top_ps, top_ks
|
||||
|
||||
|
||||
@ -267,95 +268,154 @@ def _apply_top_p_top_k(
|
||||
def _get_topk_logprobs(
|
||||
logprobs: torch.Tensor,
|
||||
num_logprobs: Optional[int],
|
||||
) -> Dict[int, float]:
|
||||
) -> List[Dict[int, float]]:
|
||||
num_seqs = logprobs.size(0)
|
||||
if num_logprobs is None or num_logprobs == 0:
|
||||
return {}
|
||||
return [{} for _ in range(num_seqs)]
|
||||
|
||||
topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
|
||||
if num_logprobs == 1:
|
||||
topk_logprobs = [topk_logprobs.item()]
|
||||
topk_ids = [topk_ids.item()]
|
||||
else:
|
||||
topk_logprobs = topk_logprobs.tolist()
|
||||
topk_ids = topk_ids.tolist()
|
||||
|
||||
token_to_logprob: Dict[int, float] = {}
|
||||
for token_id, logprob in zip(topk_ids, topk_logprobs):
|
||||
token_to_logprob[token_id] = logprob
|
||||
return token_to_logprob
|
||||
all_topk_logprobs, all_topk_ids = torch.topk(logprobs,
|
||||
num_logprobs,
|
||||
dim=-1)
|
||||
all_topk_logprobs = all_topk_logprobs.cpu()
|
||||
all_topk_ids = all_topk_ids.cpu()
|
||||
all_token_to_logprob = []
|
||||
for topk_logprobs, topk_ids in zip(all_topk_logprobs, all_topk_ids):
|
||||
token_to_logprob: Dict[int, float] = {}
|
||||
for token_id, logprob in zip(topk_ids, topk_logprobs):
|
||||
token_to_logprob[token_id.item()] = logprob.item()
|
||||
all_token_to_logprob.append(token_to_logprob)
|
||||
return all_token_to_logprob
|
||||
|
||||
|
||||
def _sample_from_prompt(
|
||||
prob: torch.Tensor,
|
||||
sampling_params: SamplingParams,
|
||||
) -> List[int]:
|
||||
if sampling_params.use_beam_search:
|
||||
# Beam search.
|
||||
beam_width = sampling_params.best_of
|
||||
# Sample 2 * beam_width candidates to make sure that with high
|
||||
# probability we can get `beam_width` candidates in addition to
|
||||
# the finished sequences for the next iteration. See
|
||||
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
|
||||
# for details. See also HF reference:
|
||||
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
|
||||
_, next_token_ids = torch.topk(prob, 2 * beam_width)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
elif sampling_params.temperature < _SAMPLING_EPS:
|
||||
# Greedy sampling.
|
||||
assert sampling_params.best_of == 1
|
||||
next_token_id = torch.argmax(prob)
|
||||
next_token_ids = [next_token_id.item()]
|
||||
else:
|
||||
# Random sampling.
|
||||
# Sample `best_of` tokens for the prompt.
|
||||
num_seqs = sampling_params.best_of
|
||||
next_token_ids = torch.multinomial(prob,
|
||||
num_samples=num_seqs,
|
||||
replacement=True)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
return next_token_ids
|
||||
def _build_sequence_outputs(
|
||||
parent_ids: List[int],
|
||||
next_token_ids: List[int],
|
||||
selected_token_logprobs: torch.Tensor,
|
||||
parent_seq_ids: List[int],
|
||||
parent_logprobs: torch.Tensor,
|
||||
num_output_logprobs: Optional[int],
|
||||
) -> List[SequenceOutputs]:
|
||||
# Get top-k log probabilities for the next tokens.
|
||||
next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs)
|
||||
seq_outputs: List[SequenceOutputs] = []
|
||||
for parent_id, next_token_id, token_logprob in zip(
|
||||
parent_ids, next_token_ids, selected_token_logprobs):
|
||||
output_logprobs = next_logprobs[parent_id].copy()
|
||||
output_logprobs[next_token_id] = token_logprob
|
||||
seq_outputs.append(
|
||||
SequenceOutputs(parent_seq_ids[parent_id], next_token_id,
|
||||
output_logprobs))
|
||||
return seq_outputs
|
||||
|
||||
|
||||
def _sample_from_generation_tokens(
|
||||
seq_ids: List[int],
|
||||
probs: torch.Tensor,
|
||||
def _greedy_sample(
|
||||
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||
logprobs: torch.Tensor,
|
||||
seq_logprobs: List[float],
|
||||
sampling_params: SamplingParams,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
# NOTE(woosuk): sampling_params.best_of can be greater than
|
||||
# len(seq_ids) because some sequences in the group might have
|
||||
# been already terminated.
|
||||
if sampling_params.use_beam_search:
|
||||
# Beam search.
|
||||
# Add cumulative logprobs for the sequences in the group.
|
||||
seq_logprobs = torch.tensor(seq_logprobs,
|
||||
dtype=torch.float,
|
||||
device=logprobs.device)
|
||||
logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
samples = torch.argmax(logprobs, dim=-1).cpu()
|
||||
sample_idx = 0
|
||||
results = []
|
||||
for seq_group in selected_seq_groups:
|
||||
seq_ids, _ = seq_group
|
||||
num_parent_seqs = len(seq_ids)
|
||||
assert num_parent_seqs == 1, (
|
||||
"Greedy sampling should have only one seq.")
|
||||
parent_ids = list(range(num_parent_seqs))
|
||||
next_token_ids = [samples[sample_idx].item()]
|
||||
results.append((next_token_ids, parent_ids))
|
||||
sample_idx += num_parent_seqs
|
||||
assert sample_idx == logprobs.size(0)
|
||||
return results
|
||||
|
||||
vocab_size = logprobs.size(-1)
|
||||
beam_width = len(seq_ids)
|
||||
_, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
|
||||
topk_ids = topk_ids.tolist()
|
||||
seq_idx = [i // vocab_size for i in topk_ids]
|
||||
parent_seq_ids = [seq_ids[i] for i in seq_idx]
|
||||
next_token_ids = [i % vocab_size for i in topk_ids]
|
||||
elif sampling_params.temperature < _SAMPLING_EPS:
|
||||
# Greedy sampling.
|
||||
assert len(seq_ids) == 1
|
||||
next_token_id = torch.argmax(probs, dim=-1)
|
||||
next_token_ids = [int(next_token_id.item())]
|
||||
parent_seq_ids = seq_ids
|
||||
else:
|
||||
# Random sampling.
|
||||
# Sample 1 token for each sequence in the group.
|
||||
next_token_ids = torch.multinomial(probs,
|
||||
num_samples=1,
|
||||
replacement=True)
|
||||
next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
|
||||
parent_seq_ids = seq_ids
|
||||
return parent_seq_ids, next_token_ids
|
||||
|
||||
def _random_sample(
|
||||
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||
is_prompts: List[bool],
|
||||
probs: torch.Tensor,
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
# Find the maximum best_of value of the prompt phase requests.
|
||||
max_best_of = 1
|
||||
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
||||
if is_prompt:
|
||||
seq_ids, sampling_params = seq_group
|
||||
max_best_of = max(max_best_of, sampling_params.best_of)
|
||||
random_samples = torch.multinomial(probs,
|
||||
num_samples=max_best_of,
|
||||
replacement=True).cpu()
|
||||
sample_idx = 0
|
||||
results = []
|
||||
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
||||
seq_ids, sampling_params = seq_group
|
||||
num_parent_seqs = len(seq_ids)
|
||||
if is_prompt:
|
||||
# Prompt phase.
|
||||
assert num_parent_seqs == 1, (
|
||||
"Prompt input should have only one seq.")
|
||||
parent_ids = [0] * sampling_params.best_of
|
||||
next_token_ids = random_samples[
|
||||
sample_idx, :sampling_params.best_of].tolist()
|
||||
else:
|
||||
# Generation phase.
|
||||
parent_ids = list(range(num_parent_seqs))
|
||||
next_token_ids = random_samples[sample_idx:sample_idx +
|
||||
num_parent_seqs, 0].tolist()
|
||||
results.append((next_token_ids, parent_ids))
|
||||
sample_idx += num_parent_seqs
|
||||
assert sample_idx == probs.size(0)
|
||||
return results
|
||||
|
||||
|
||||
def _beam_search_sample(
|
||||
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||
is_prompts: List[bool],
|
||||
seq_data: Dict[int, SequenceData],
|
||||
logprobs: torch.Tensor,
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
# We sample 2 * beam_width candidates to make sure that with high
|
||||
# probability we can get `beam_width` candidates in addition to
|
||||
# the finished sequences for the next iteration. See
|
||||
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
|
||||
# for details. See also HF reference:
|
||||
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
|
||||
#
|
||||
# Note: Beam search is not vectorized, so its speed can be slower than
|
||||
# other sampling methods.
|
||||
sample_idx = 0
|
||||
results = []
|
||||
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
||||
seq_ids, sampling_params = seq_group
|
||||
num_parent_seqs = len(seq_ids)
|
||||
beam_width = sampling_params.best_of
|
||||
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
|
||||
if is_prompt:
|
||||
# Prompt phase.
|
||||
assert num_parent_seqs == 1, (
|
||||
"Prompt input should have only one seq.")
|
||||
parent_ids = [0] * (2 * beam_width)
|
||||
_, next_token_ids = torch.topk(seq_group_logprobs[0],
|
||||
2 * beam_width)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
else:
|
||||
# Generation phase.
|
||||
cumulative_logprobs = [
|
||||
seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
|
||||
]
|
||||
cumulative_logprobs = torch.tensor(
|
||||
cumulative_logprobs,
|
||||
dtype=torch.float,
|
||||
device=seq_group_logprobs.device)
|
||||
seq_group_logprobs = (seq_group_logprobs +
|
||||
cumulative_logprobs.unsqueeze(dim=1))
|
||||
_, topk_ids = torch.topk(seq_group_logprobs.flatten(),
|
||||
2 * beam_width)
|
||||
topk_ids = topk_ids.tolist()
|
||||
vocab_size = seq_group_logprobs.size(-1)
|
||||
parent_ids = [i // vocab_size for i in topk_ids]
|
||||
next_token_ids = [i % vocab_size for i in topk_ids]
|
||||
results.append((next_token_ids, parent_ids))
|
||||
sample_idx += num_parent_seqs
|
||||
assert sample_idx == logprobs.size(0)
|
||||
return results
|
||||
|
||||
|
||||
def _sample(
|
||||
@ -363,65 +423,80 @@ def _sample(
|
||||
logprobs: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> SamplerOutput:
|
||||
seq_outputs: SamplerOutput = []
|
||||
|
||||
# TODO(woosuk): Optimize.
|
||||
idx = 0
|
||||
categorized_seq_group_ids = {t: [] for t in SamplingType}
|
||||
category_num_tokens = {t: 0 for t in SamplingType}
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_group_outputs: List[SequenceOutputs] = []
|
||||
seq_ids, sampling_params = seq_group
|
||||
if i < input_metadata.num_prompts:
|
||||
# Generate the next tokens for a prompt input.
|
||||
assert len(seq_ids) == 1, "Prompt input should have only one seq."
|
||||
parent_seq_id = seq_ids[0]
|
||||
prob = probs[idx]
|
||||
logprob = logprobs[idx]
|
||||
idx += 1
|
||||
sampling_type = sampling_params.sampling_type
|
||||
categorized_seq_group_ids[sampling_type].append(i)
|
||||
num_seqs = len(seq_ids)
|
||||
category_num_tokens[sampling_type] += num_seqs
|
||||
|
||||
# Sample the next tokens.
|
||||
next_token_ids = _sample_from_prompt(prob, sampling_params)
|
||||
# Get top-k log probabilities for the next tokens.
|
||||
next_logprobs = _get_topk_logprobs(logprob,
|
||||
sampling_params.logprobs)
|
||||
|
||||
# Build the output.
|
||||
for next_token_id in next_token_ids:
|
||||
output_logprobs = next_logprobs.copy()
|
||||
output_logprobs[next_token_id] = logprob[next_token_id].item()
|
||||
seq_group_outputs.append(
|
||||
SequenceOutputs(parent_seq_id, next_token_id,
|
||||
output_logprobs))
|
||||
seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
|
||||
category_start_idx = 0
|
||||
for sampling_type in SamplingType:
|
||||
seq_group_ids = categorized_seq_group_ids[sampling_type]
|
||||
seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
|
||||
is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
|
||||
num_tokens = category_num_tokens[sampling_type]
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
category_logprobs = logprobs[category_start_idx:category_start_idx +
|
||||
num_tokens]
|
||||
category_probs = probs[category_start_idx:category_start_idx +
|
||||
num_tokens]
|
||||
if sampling_type == SamplingType.GREEDY:
|
||||
sample_results = _greedy_sample(seq_groups, category_logprobs)
|
||||
elif sampling_type == SamplingType.RANDOM:
|
||||
sample_results = _random_sample(seq_groups, is_prompts,
|
||||
category_probs)
|
||||
elif sampling_type == SamplingType.BEAM:
|
||||
sample_results = _beam_search_sample(seq_groups, is_prompts,
|
||||
input_metadata.seq_data,
|
||||
category_logprobs)
|
||||
else:
|
||||
# Generate the next tokens for generation tokens.
|
||||
raise ValueError(f"Unsupported sampling type: {sampling_type}")
|
||||
|
||||
# Batched query for logprobs of selected token
|
||||
batched_logprobs_query_seq_indices: List[int] = []
|
||||
batched_logprobs_query_token_indices: List[int] = []
|
||||
sample_idx = 0
|
||||
for seq_group_id, seq_group, sample_result in zip(
|
||||
seq_group_ids, seq_groups, sample_results):
|
||||
seq_ids, sampling_params = seq_group
|
||||
next_token_ids, parent_ids = sample_result
|
||||
num_parent_seqs = len(seq_ids)
|
||||
prob = probs[idx:idx + num_parent_seqs]
|
||||
logprob = logprobs[idx:idx + num_parent_seqs]
|
||||
idx += num_parent_seqs
|
||||
batched_logprobs_query_seq_indices.extend(
|
||||
[sample_idx + parent_id for parent_id in parent_ids])
|
||||
batched_logprobs_query_token_indices.extend(next_token_ids)
|
||||
sample_idx += num_parent_seqs
|
||||
assert sample_idx == num_tokens
|
||||
batched_logprobs_query_result = category_logprobs[[
|
||||
batched_logprobs_query_seq_indices,
|
||||
batched_logprobs_query_token_indices
|
||||
]].tolist()
|
||||
|
||||
# Sample the next tokens.
|
||||
seq_logprobs = [
|
||||
input_metadata.seq_data[seq_id].cumulative_logprob
|
||||
for seq_id in seq_ids
|
||||
]
|
||||
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
|
||||
seq_ids, prob, logprob, seq_logprobs, sampling_params)
|
||||
# Build the sequence outputs.
|
||||
sample_idx = 0
|
||||
result_idx = 0
|
||||
for seq_group_id, seq_group, sample_result in zip(
|
||||
seq_group_ids, seq_groups, sample_results):
|
||||
seq_ids, sampling_params = seq_group
|
||||
next_token_ids, parent_ids = sample_result
|
||||
num_results = len(next_token_ids)
|
||||
num_parent_seqs = len(seq_ids)
|
||||
parent_logprobs = category_logprobs[sample_idx:sample_idx +
|
||||
num_parent_seqs]
|
||||
selected_token_logprobs = batched_logprobs_query_result[
|
||||
result_idx:result_idx + num_results]
|
||||
seq_output = _build_sequence_outputs(parent_ids, next_token_ids,
|
||||
selected_token_logprobs,
|
||||
seq_ids, parent_logprobs,
|
||||
sampling_params.logprobs)
|
||||
seq_outputs_dict[seq_group_id] = seq_output
|
||||
sample_idx += num_parent_seqs
|
||||
result_idx += num_results
|
||||
assert sample_idx == num_tokens
|
||||
category_start_idx += num_tokens
|
||||
|
||||
# Get top-k log probabilities for the next tokens.
|
||||
next_logprobs: Dict[int, Dict[int, float]] = {}
|
||||
for j, seq_id in enumerate(seq_ids):
|
||||
next_logprobs[seq_id] = _get_topk_logprobs(
|
||||
logprob[j], sampling_params.logprobs)
|
||||
|
||||
# Build the output.
|
||||
for parent_seq_id, next_token_id in zip(parent_seq_ids,
|
||||
next_token_ids):
|
||||
j = seq_ids.index(parent_seq_id)
|
||||
output_logprobs = next_logprobs[parent_seq_id].copy()
|
||||
output_logprobs[next_token_id] = logprob[j,
|
||||
next_token_id].item()
|
||||
seq_group_outputs.append(
|
||||
SequenceOutputs(parent_seq_id, next_token_id,
|
||||
output_logprobs))
|
||||
seq_outputs.append(seq_group_outputs)
|
||||
|
||||
return seq_outputs
|
||||
return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
|
||||
|
@ -8,7 +8,8 @@ from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.model_executor.models import * # pylint: disable=wildcard-import
|
||||
from vllm.model_executor.weight_utils import initialize_dummy_weights
|
||||
from vllm.model_executor.weight_utils import (get_quant_config,
|
||||
initialize_dummy_weights)
|
||||
|
||||
# TODO(woosuk): Lazy-load the model classes.
|
||||
_MODEL_REGISTRY = {
|
||||
@ -24,12 +25,18 @@ _MODEL_REGISTRY = {
|
||||
"InternLMForCausalLM": InternLMForCausalLM,
|
||||
"LlamaForCausalLM": LlamaForCausalLM,
|
||||
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
|
||||
"MistralForCausalLM": MistralForCausalLM,
|
||||
"MPTForCausalLM": MPTForCausalLM,
|
||||
"OPTForCausalLM": OPTForCausalLM,
|
||||
"QWenLMHeadModel": QWenLMHeadModel,
|
||||
"RWForCausalLM": FalconForCausalLM,
|
||||
}
|
||||
|
||||
# FIXME(woosuk): Remove this once all models support quantization.
|
||||
_MODEL_CLASSES_SUPPORT_QUANTIZATION = [
|
||||
LlamaForCausalLM,
|
||||
]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _set_default_torch_dtype(dtype: torch.dtype):
|
||||
@ -52,10 +59,38 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
|
||||
def get_model(model_config: ModelConfig) -> nn.Module:
|
||||
model_class = _get_model_architecture(model_config.hf_config)
|
||||
|
||||
# Get the quantization config.
|
||||
quant_config = None
|
||||
if model_config.quantization is not None:
|
||||
if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
|
||||
raise ValueError(
|
||||
f"Quantization is not supported for {model_class}.")
|
||||
quant_config = get_quant_config(model_config.quantization,
|
||||
model_config.model,
|
||||
model_config.download_dir)
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
if capability < quant_config.get_min_capability():
|
||||
raise ValueError(
|
||||
f"The quantization method {model_config.quantization} is not "
|
||||
"supported for the current GPU. "
|
||||
f"Minimum capability: {quant_config.get_min_capability()}. "
|
||||
f"Current capability: {capability}.")
|
||||
supported_dtypes = quant_config.get_supported_act_dtypes()
|
||||
if model_config.dtype not in supported_dtypes:
|
||||
raise ValueError(
|
||||
f"{model_config.dtype} is not supported for quantization "
|
||||
f"method {model_config.quantization}. Supported dtypes: "
|
||||
f"{supported_dtypes}")
|
||||
|
||||
with _set_default_torch_dtype(model_config.dtype):
|
||||
# Create a model instance.
|
||||
# The weights will be initialized as empty tensors.
|
||||
model = model_class(model_config.hf_config)
|
||||
if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
|
||||
model = model_class(model_config.hf_config, quant_config)
|
||||
else:
|
||||
model = model_class(model_config.hf_config)
|
||||
if model_config.load_format == "dummy":
|
||||
model = model.cuda()
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
@ -64,6 +99,6 @@ def get_model(model_config: ModelConfig) -> nn.Module:
|
||||
else:
|
||||
# Load the weights from the cached or downloaded files.
|
||||
model.load_weights(model_config.model, model_config.download_dir,
|
||||
model_config.load_format)
|
||||
model_config.load_format, model_config.revision)
|
||||
model = model.cuda()
|
||||
return model.eval()
|
||||
|
@ -12,6 +12,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.model_executor.models.mpt import MPTForCausalLM
|
||||
from vllm.model_executor.models.opt import OPTForCausalLM
|
||||
from vllm.model_executor.models.qwen import QWenLMHeadModel
|
||||
from vllm.model_executor.models.mistral import MistralForCausalLM
|
||||
|
||||
__all__ = [
|
||||
"AquilaForCausalLM",
|
||||
@ -28,4 +29,5 @@ __all__ = [
|
||||
"MPTForCausalLM",
|
||||
"OPTForCausalLM",
|
||||
"QWenLMHeadModel",
|
||||
"MistralForCausalLM",
|
||||
]
|
||||
|
@ -105,6 +105,8 @@ class AquilaAttention(nn.Module):
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
max_position_embeddings: int = 8192,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -119,6 +121,8 @@ class AquilaAttention(nn.Module):
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
@ -140,6 +144,8 @@ class AquilaAttention(nn.Module):
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
rotary_dim=self.head_dim,
|
||||
base=self.rope_theta,
|
||||
max_position=self.max_position_embeddings,
|
||||
)
|
||||
|
||||
def forward(
|
||||
@ -164,10 +170,15 @@ class AquilaDecoderLayer(nn.Module):
|
||||
def __init__(self, config: AquilaConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = AquilaAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_attention_heads,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
)
|
||||
self.mlp = AquilaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
@ -288,7 +299,8 @@ class AquilaForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||
@ -305,7 +317,7 @@ class AquilaForCausalLM(nn.Module):
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
|
@ -111,6 +111,8 @@ class BaiChuanAttention(nn.Module):
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
position_embedding: str,
|
||||
rope_theta: float = 10000,
|
||||
max_position_embeddings: int = 8192,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -122,6 +124,8 @@ class BaiChuanAttention(nn.Module):
|
||||
tensor_model_parallel_world_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.postion_embedding = position_embedding
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
self.W_pack = ColumnParallelLinear(
|
||||
@ -151,10 +155,13 @@ class BaiChuanAttention(nn.Module):
|
||||
scaling, alibi_slopes)
|
||||
else:
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
rotary_dim=self.head_dim)
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
rotary_dim=self.head_dim,
|
||||
base=self.rope_theta,
|
||||
max_position=self.max_position_embeddings)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -183,10 +190,15 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
def __init__(self, config: BaiChuanConfig, position_embedding: str):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = BaiChuanAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
position_embedding=position_embedding,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
)
|
||||
self.mlp = BaiChuanMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
@ -303,13 +315,14 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
|
@ -279,11 +279,12 @@ class BloomForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if name == "lm_head.weight":
|
||||
# Since hidden_states are parallelized, we need to
|
||||
# load lm_head.weight in parallel.
|
||||
|
@ -161,12 +161,17 @@ class FalconAttention(nn.Module):
|
||||
"Rotary and alibi are mutually exclusive.")
|
||||
|
||||
if self.use_rotary:
|
||||
# TODO(zhuohan): Pass in correct `max_position``
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_dim,
|
||||
self.inv_norm_factor,
|
||||
rotary_dim=self.head_dim,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
max_position_embeddings = getattr(config,
|
||||
"max_position_embeddings", 8192)
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.inv_norm_factor,
|
||||
base=rope_theta,
|
||||
max_position=max_position_embeddings,
|
||||
rotary_dim=self.head_dim,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
elif self.use_alibi:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
head_start = tp_rank * self.num_heads
|
||||
@ -420,7 +425,8 @@ class FalconForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tp_size = (get_tensor_model_parallel_world_size())
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
@ -452,7 +458,7 @@ class FalconForCausalLM(nn.Module):
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "query_key_value" in name:
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
loaded_weight_size = loaded_weight.size()
|
||||
|
@ -231,14 +231,15 @@ class GPT2LMHeadModel(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "lm_head.weight" in name:
|
||||
# GPT-2 ties the weights of the embedding layer and the final
|
||||
# linear layer.
|
||||
|
@ -259,14 +259,15 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "lm_head.weight" in name:
|
||||
# GPT-2 ties the weights of the embedding layer and the final
|
||||
# linear layer.
|
||||
|
@ -67,11 +67,17 @@ class GPTJAttention(nn.Module):
|
||||
scaling = self.head_size**-0.5
|
||||
assert getattr(config, "rotary", True)
|
||||
assert config.rotary_dim % 2 == 0
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_size,
|
||||
scaling,
|
||||
config.rotary_dim,
|
||||
is_neox_style=False)
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
scaling,
|
||||
config.rotary_dim,
|
||||
base=rope_theta,
|
||||
max_position=max_position_embeddings,
|
||||
is_neox_style=False)
|
||||
self.warmup = False
|
||||
|
||||
def forward(
|
||||
@ -222,11 +228,12 @@ class GPTJForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "attn.bias" in name or "attn.masked_bias" in name:
|
||||
continue
|
||||
|
||||
|
@ -68,8 +68,16 @@ class GPTNeoXAttention(nn.Module):
|
||||
scaling = self.head_size**-0.5
|
||||
rotary_dim = int(self.head_size * config.rotary_pct)
|
||||
assert rotary_dim % 2 == 0
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
|
||||
scaling, rotary_dim)
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
scaling,
|
||||
rotary_dim,
|
||||
base=rope_theta,
|
||||
max_position=max_position_embeddings)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -231,11 +239,12 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if ("attention.bias" in name or "attention.masked_bias" in name
|
||||
or "rotary_emb.inv_freq" in name):
|
||||
continue
|
||||
|
@ -59,6 +59,8 @@ class InternLMAttention(nn.Module):
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
max_position_embeddings: int = 8192,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -70,6 +72,8 @@ class InternLMAttention(nn.Module):
|
||||
tensor_model_parallel_world_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
@ -85,10 +89,13 @@ class InternLMAttention(nn.Module):
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
rotary_dim=self.head_dim)
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
base=self.rope_theta,
|
||||
max_position=self.max_position_embeddings,
|
||||
rotary_dim=self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -112,9 +119,14 @@ class InternLMDecoderLayer(nn.Module):
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = InternLMAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
)
|
||||
self.mlp = InternLMMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
@ -233,12 +245,13 @@ class InternLMForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
|
@ -25,7 +25,7 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -36,13 +36,15 @@ from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.model_executor.layers.quantized_linear import ParallelLinear
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.quantization_utils import QuantizationConfig
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -55,18 +57,21 @@ class LlamaMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = ColumnParallelLinear(hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.gate_up_proj = ParallelLinear.column(hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config)
|
||||
self.down_proj = ParallelLinear.row(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -87,7 +92,10 @@ class LlamaAttention(nn.Module):
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
):
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
@ -102,28 +110,34 @@ class LlamaAttention(nn.Module):
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
self.qkv_proj = ParallelLinear.column(
|
||||
hidden_size,
|
||||
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||
self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.o_proj = ParallelLinear.row(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
base=self.rope_theta,
|
||||
rotary_dim=self.head_dim,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
base=self.rope_theta,
|
||||
max_position=self.max_position_embeddings,
|
||||
rotary_dim=self.head_dim,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
rope_scaling=rope_scaling)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -144,21 +158,32 @@ class LlamaAttention(nn.Module):
|
||||
|
||||
class LlamaDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = LlamaAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.mlp = LlamaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -195,7 +220,11 @@ class LlamaDecoderLayer(nn.Module):
|
||||
|
||||
class LlamaModel(nn.Module):
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -205,7 +234,8 @@ class LlamaModel(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
vocab_size, config.hidden_size, perform_initialization=False)
|
||||
self.layers = nn.ModuleList([
|
||||
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
|
||||
LlamaDecoderLayer(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@ -237,16 +267,23 @@ class LlamaModel(nn.Module):
|
||||
|
||||
class LlamaForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = LlamaModel(config)
|
||||
self.quant_config = quant_config
|
||||
self.model = LlamaModel(config, quant_config)
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
# NOTE: The LM head is not quantized.
|
||||
self.lm_head = ParallelLinear.column(config.hidden_size,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
quant_config=None)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
@ -263,15 +300,28 @@ class LlamaForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
|
||||
]
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
_column_parallel_layers = []
|
||||
_row_parallel_layers = ["o_proj", "down_proj"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
if self.quant_config is None:
|
||||
weight_suffixes = ["weight"]
|
||||
else:
|
||||
weight_suffixes = self.quant_config.get_tp_tensor_names()
|
||||
|
||||
column_parallel_weights: List[str] = []
|
||||
for layer in self._column_parallel_layers:
|
||||
for suffix in weight_suffixes:
|
||||
column_parallel_weights.append(f"{layer}.{suffix}")
|
||||
row_parallel_weights: List[str] = []
|
||||
for layer in self._row_parallel_layers:
|
||||
for suffix in weight_suffixes:
|
||||
row_parallel_weights.append(f"{layer}.{suffix}")
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||
@ -288,15 +338,30 @@ class LlamaForCausalLM(nn.Module):
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
is_packed = False
|
||||
is_transposed = False
|
||||
if self.quant_config is not None:
|
||||
is_packed = self.quant_config.is_packed(name)
|
||||
is_transposed = self.quant_config.is_transposed(name)
|
||||
if is_transposed:
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
loaded_weight = loaded_weight.T
|
||||
|
||||
is_attention_weight = False
|
||||
for weight_name, shard_size, offset in attention_weight_specs:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "qkv_proj")]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
if is_packed:
|
||||
shard_size //= self.quant_config.pack_factor
|
||||
offset //= self.quant_config.pack_factor
|
||||
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
@ -315,6 +380,9 @@ class LlamaForCausalLM(nn.Module):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
@ -329,6 +397,8 @@ class LlamaForCausalLM(nn.Module):
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
@ -336,6 +406,6 @@ class LlamaForCausalLM(nn.Module):
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
column_parallel_weights,
|
||||
row_parallel_weights,
|
||||
tensor_model_parallel_rank)
|
||||
|
404
vllm/model_executor/models/mistral.py
Normal file
404
vllm/model_executor/models/mistral.py
Normal file
@ -0,0 +1,404 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights.
|
||||
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.quantized_linear import ParallelLinear
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.quantization_utils import QuantizationConfig
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.mistral import MistralConfig
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
class MistralMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = ParallelLinear.column(hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config)
|
||||
self.down_proj = ParallelLinear.row(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class MistralAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
sliding_window: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
self.num_kv_heads = self.total_num_kv_heads // tp_size
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
self.qkv_proj = ParallelLinear.column(
|
||||
hidden_size,
|
||||
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||
self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = ParallelLinear.row(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
base=self.rope_theta,
|
||||
max_position=max_position,
|
||||
rotary_dim=self.head_dim,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
sliding_window=self.sliding_window)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||
input_metadata, cache_event)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class MistralDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MistralConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
self.self_attn = MistralAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
max_position=config.max_position_embeddings,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
quant_config=quant_config,
|
||||
sliding_window=config.sliding_window)
|
||||
self.mlp = MistralMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MistralModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MistralConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
vocab_size, config.hidden_size, perform_initialization=False)
|
||||
self.layers = nn.ModuleList([
|
||||
MistralDecoderLayer(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
for i in range(len(self.layers)):
|
||||
if cache_events is None:
|
||||
cache_event = None
|
||||
else:
|
||||
cache_event = cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
kv_caches[i],
|
||||
input_metadata,
|
||||
cache_event,
|
||||
)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MistralForCausalLM(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MistralConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = MistralModel(config, quant_config)
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
# NOTE: The LM head is not quantized.
|
||||
self.lm_head = ParallelLinear.column(config.hidden_size,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
quant_config=None)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_layers = []
|
||||
_row_parallel_layers = ["o_proj", "down_proj"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
if self.quant_config is None:
|
||||
weight_suffixes = ["weight"]
|
||||
else:
|
||||
weight_suffixes = self.quant_config.get_tp_tensor_names()
|
||||
|
||||
column_parallel_weights: List[str] = []
|
||||
for layer in self._column_parallel_layers:
|
||||
for suffix in weight_suffixes:
|
||||
column_parallel_weights.append(f"{layer}.{suffix}")
|
||||
row_parallel_weights: List[str] = []
|
||||
for layer in self._row_parallel_layers:
|
||||
for suffix in weight_suffixes:
|
||||
row_parallel_weights.append(f"{layer}.{suffix}")
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||
kv_proj_shard_size = (self.config.hidden_size //
|
||||
self.config.num_attention_heads *
|
||||
self.config.num_key_value_heads // tp_size)
|
||||
attention_weight_specs = [
|
||||
# (weight_name, shard_size, offset)
|
||||
("q_proj", q_proj_shard_size, 0),
|
||||
("k_proj", kv_proj_shard_size, q_proj_shard_size),
|
||||
("v_proj", kv_proj_shard_size,
|
||||
q_proj_shard_size + kv_proj_shard_size),
|
||||
]
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
is_packed = False
|
||||
is_transposed = False
|
||||
if self.quant_config is not None:
|
||||
is_packed = self.quant_config.is_packed(name)
|
||||
is_transposed = self.quant_config.is_transposed(name)
|
||||
if is_transposed:
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
loaded_weight = loaded_weight.T
|
||||
|
||||
is_attention_weight = False
|
||||
for weight_name, shard_size, offset in attention_weight_specs:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "qkv_proj")]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
if is_packed:
|
||||
shard_size //= self.quant_config.pack_factor
|
||||
offset //= self.quant_config.pack_factor
|
||||
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[offset:offset + shard_size]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_attention_weight = True
|
||||
break
|
||||
if is_attention_weight:
|
||||
continue
|
||||
|
||||
is_gate_up_weight = False
|
||||
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
break
|
||||
if is_gate_up_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tensor_model_parallel_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
column_parallel_weights,
|
||||
row_parallel_weights,
|
||||
tensor_model_parallel_rank)
|
@ -244,12 +244,13 @@ class MPTForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "Wqkv" in name:
|
||||
# NOTE(woosuk): MPT's fused QKV has the shape of
|
||||
# [3 * num_heads * head_size, hidden_size].
|
||||
|
@ -297,12 +297,13 @@ class OPTForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto"):
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "lm_head.weight" in name:
|
||||
continue
|
||||
|
||||
|
@ -8,7 +8,7 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -76,8 +76,12 @@ class QWenMLP(nn.Module):
|
||||
|
||||
class QWenAttention(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, num_heads: int,
|
||||
max_position_embeddings: int):
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
max_position_embeddings: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
|
||||
@ -109,8 +113,9 @@ class QWenAttention(nn.Module):
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
rotary_dim=self.head_dim,
|
||||
base=rope_theta,
|
||||
max_position=max_position_embeddings,
|
||||
)
|
||||
rope_scaling=rope_scaling)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -135,14 +140,19 @@ class QWenBlock(nn.Module):
|
||||
|
||||
def __init__(self, config: QWenConfig):
|
||||
super().__init__()
|
||||
self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.attn = QWenAttention(config.n_embd, config.num_attention_heads,
|
||||
config.max_position_embeddings)
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
self.attn = QWenAttention(config.hidden_size,
|
||||
config.num_attention_heads,
|
||||
config.max_position_embeddings,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling)
|
||||
|
||||
self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.mlp = QWenMLP(config.n_embd, config.ffn_hidden_size // 2)
|
||||
self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -181,11 +191,11 @@ class QWenModel(nn.Module):
|
||||
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.wte = VocabParallelEmbedding(vocab_size,
|
||||
config.n_embd,
|
||||
config.hidden_size,
|
||||
perform_initialization=False)
|
||||
self.h = nn.ModuleList(
|
||||
[QWenBlock(config) for _ in range(config.num_hidden_layers)])
|
||||
self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -221,7 +231,7 @@ class QWenLMHeadModel(nn.Module):
|
||||
self.transformer = QWenModel(config)
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.lm_head = ColumnParallelLinear(
|
||||
config.n_embd,
|
||||
config.hidden_size,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
@ -251,13 +261,14 @@ class QWenLMHeadModel(nn.Module):
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
# Parts of the code here are adapted from PyTorch
|
||||
# repo: https://github.com/pytorch/pytorch
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -16,13 +16,11 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from .mappings import (
|
||||
copy_to_tensor_model_parallel_region,
|
||||
gather_from_tensor_model_parallel_region,
|
||||
reduce_from_tensor_model_parallel_region,
|
||||
scatter_to_tensor_model_parallel_region,
|
||||
)
|
||||
|
||||
from .random import get_cuda_rng_tracker
|
||||
from .utils import (
|
||||
divide,
|
||||
VocabUtility,
|
||||
@ -65,59 +63,6 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
|
||||
maybe_copy(attribute)
|
||||
|
||||
|
||||
def _initialize_affine_weight_gpu(weight, init_method,
|
||||
partition_dim, stride=1):
|
||||
"""Initialize affine weight for model parallel on GPU."""
|
||||
|
||||
set_tensor_model_parallel_attributes(tensor=weight,
|
||||
is_parallel=True,
|
||||
dim=partition_dim,
|
||||
stride=stride)
|
||||
|
||||
with get_cuda_rng_tracker().fork():
|
||||
init_method(weight)
|
||||
|
||||
|
||||
def _initialize_affine_weight_cpu(weight, output_size, input_size,
|
||||
per_partition_size, partition_dim,
|
||||
init_method, stride=1,
|
||||
return_master_weight=False,
|
||||
*, params_dtype=None):
|
||||
"""Initialize affine weight for model parallel.
|
||||
|
||||
Build the master weight on all processes and scatter
|
||||
the relevant chunk."""
|
||||
|
||||
set_tensor_model_parallel_attributes(tensor=weight,
|
||||
is_parallel=True,
|
||||
dim=partition_dim,
|
||||
stride=stride)
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
# Initialize master weight
|
||||
master_weight = torch.empty(output_size, input_size,
|
||||
dtype=torch.float,
|
||||
requires_grad=False)
|
||||
init_method(master_weight)
|
||||
master_weight = master_weight.to(dtype=params_dtype)
|
||||
|
||||
# Split and copy
|
||||
per_partition_per_stride_size = divide(per_partition_size, stride)
|
||||
weight_list = torch.split(master_weight, per_partition_per_stride_size,
|
||||
dim=partition_dim)
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
my_weight_list = weight_list[rank::world_size]
|
||||
|
||||
with torch.no_grad():
|
||||
torch.cat(my_weight_list, dim=partition_dim, out=weight)
|
||||
if return_master_weight:
|
||||
return master_weight
|
||||
return None
|
||||
|
||||
|
||||
class VocabParallelEmbedding(torch.nn.Module):
|
||||
"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
@ -138,8 +83,11 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
init_method=init.xavier_normal_,
|
||||
params_dtype: torch.dtype=None,
|
||||
use_cpu_initialization: bool=False,
|
||||
perform_initialization: bool=True):
|
||||
perform_initialization: bool=False):
|
||||
super(VocabParallelEmbedding, self).__init__()
|
||||
assert not perform_initialization
|
||||
assert not use_cpu_initialization
|
||||
|
||||
# Keep the input dimensions.
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
@ -162,23 +110,9 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
self.num_embeddings_per_partition = self.vocab_end_index - \
|
||||
self.vocab_start_index
|
||||
|
||||
# Allocate weights and initialize.
|
||||
if use_cpu_initialization:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.num_embeddings_per_partition, self.embedding_dim,
|
||||
dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
_initialize_affine_weight_cpu(
|
||||
self.weight, self.num_embeddings, self.embedding_dim,
|
||||
self.num_embeddings_per_partition, 0, init_method,
|
||||
params_dtype=params_dtype)
|
||||
else:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.num_embeddings_per_partition, self.embedding_dim,
|
||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
||||
partition_dim=0, stride=1)
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.num_embeddings_per_partition, self.embedding_dim,
|
||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
||||
|
||||
def forward(self, input_):
|
||||
if self.tensor_model_parallel_size > 1:
|
||||
@ -238,9 +172,12 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
skip_bias_add=False,
|
||||
params_dtype=None,
|
||||
use_cpu_initialization=False,
|
||||
perform_initialization=True,
|
||||
perform_initialization=False,
|
||||
quant_config=None,
|
||||
):
|
||||
super(ColumnParallelLinear, self).__init__()
|
||||
assert not perform_initialization
|
||||
assert not use_cpu_initialization
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
@ -250,6 +187,7 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
self.world_size = get_tensor_model_parallel_world_size()
|
||||
self.output_size_per_partition = divide(output_size, self.world_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.quant_config = quant_config
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
@ -257,33 +195,13 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
# Parameters.
|
||||
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
||||
# we allocate the transpose.
|
||||
# Initialize weight.
|
||||
if use_cpu_initialization:
|
||||
self.weight = Parameter(torch.empty(self.output_size_per_partition,
|
||||
self.input_size,
|
||||
dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
self.master_weight = _initialize_affine_weight_cpu(
|
||||
self.weight, self.output_size, self.input_size,
|
||||
self.output_size_per_partition, 0, init_method,
|
||||
stride=stride, return_master_weight=keep_master_weight_for_test)
|
||||
else:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.output_size_per_partition, self.input_size,
|
||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
||||
partition_dim=0, stride=stride)
|
||||
self.create_weights(params_dtype)
|
||||
|
||||
if bias:
|
||||
if use_cpu_initialization:
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.output_size_per_partition, dtype=params_dtype))
|
||||
else:
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.output_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.output_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
|
||||
# Always initialize bias to zero.
|
||||
with torch.no_grad():
|
||||
@ -291,6 +209,17 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.output_size_per_partition, self.input_size,
|
||||
device=torch.cuda.current_device(), dtype=dtype))
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
return F.linear(x, self.weight, bias)
|
||||
|
||||
def forward(self, input_):
|
||||
"""Forward of ColumnParallelLinear
|
||||
@ -306,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
|
||||
input_parallel = input_
|
||||
# Matrix multiply.
|
||||
output_parallel = F.linear(input_parallel, self.weight, bias)
|
||||
output_parallel = self.apply_weights(input_parallel, bias)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_from_tensor_model_parallel_region(output_parallel)
|
||||
@ -359,10 +288,13 @@ class RowParallelLinear(torch.nn.Module):
|
||||
skip_bias_add=False,
|
||||
params_dtype=None,
|
||||
use_cpu_initialization=False,
|
||||
perform_initialization=True,
|
||||
perform_initialization=False,
|
||||
reduce_results=True,
|
||||
quant_config=None,
|
||||
):
|
||||
super(RowParallelLinear, self).__init__()
|
||||
assert not perform_initialization
|
||||
assert not use_cpu_initialization
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
@ -376,47 +308,32 @@ class RowParallelLinear(torch.nn.Module):
|
||||
self.world_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, self.world_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.create_weights(params_dtype)
|
||||
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
"results can lead to incorrect results")
|
||||
|
||||
# Parameters.
|
||||
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
||||
# we allocate the transpose.
|
||||
# Initialize weight.
|
||||
if use_cpu_initialization:
|
||||
self.weight = Parameter(torch.empty(self.output_size,
|
||||
self.input_size_per_partition,
|
||||
dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
self.master_weight = _initialize_affine_weight_cpu(
|
||||
self.weight, self.output_size, self.input_size,
|
||||
self.input_size_per_partition, 1, init_method,
|
||||
stride=stride, return_master_weight=keep_master_weight_for_test,
|
||||
params_dtype=params_dtype)
|
||||
else:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.output_size, self.input_size_per_partition,
|
||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
||||
partition_dim=1, stride=stride)
|
||||
if bias:
|
||||
if use_cpu_initialization:
|
||||
self.bias = Parameter(torch.empty(self.output_size,
|
||||
dtype=params_dtype))
|
||||
else:
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.output_size, device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.output_size, device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
|
||||
# Always initialize bias to zero.
|
||||
with torch.no_grad():
|
||||
self.bias.zero_()
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.weight_t = self.weight.t()
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.output_size, self.input_size_per_partition,
|
||||
device=torch.cuda.current_device(), dtype=dtype))
|
||||
|
||||
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(x, self.weight)
|
||||
|
||||
def forward(self, input_):
|
||||
"""Forward of RowParallelLinear
|
||||
@ -434,7 +351,7 @@ class RowParallelLinear(torch.nn.Module):
|
||||
else:
|
||||
input_parallel = scatter_to_tensor_model_parallel_region(input_)
|
||||
# Matrix multiply.
|
||||
output_parallel = F.linear(input_parallel, self.weight)
|
||||
output_parallel = self.apply_weights(input_parallel)
|
||||
if self.reduce_results and self.world_size > 1:
|
||||
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
|
||||
else:
|
||||
|
20
vllm/model_executor/quantization_utils/__init__.py
Normal file
20
vllm/model_executor/quantization_utils/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
from typing import Type
|
||||
|
||||
from vllm.model_executor.quantization_utils.awq import AWQConfig
|
||||
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
||||
|
||||
_QUANTIZATION_REGISTRY = {
|
||||
"awq": AWQConfig,
|
||||
}
|
||||
|
||||
|
||||
def get_quant_class(quantization: str) -> Type[QuantizationConfig]:
|
||||
if quantization not in _QUANTIZATION_REGISTRY:
|
||||
raise ValueError(f"Invalid quantization method: {quantization}")
|
||||
return _QUANTIZATION_REGISTRY[quantization]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"QuantizationConfig",
|
||||
"get_quant_class",
|
||||
]
|
72
vllm/model_executor/quantization_utils/awq.py
Normal file
72
vllm/model_executor/quantization_utils/awq.py
Normal file
@ -0,0 +1,72 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
||||
|
||||
|
||||
class AWQConfig(QuantizationConfig):
|
||||
"""Config class for AWQ.
|
||||
|
||||
Reference: https://arxiv.org/abs/2306.00978
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
) -> None:
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
|
||||
if self.weight_bits != 4:
|
||||
raise ValueError(
|
||||
"Currently, only 4-bit weight quantization is supported for "
|
||||
f"AWQ, but got {self.weight_bits} bits.")
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"AWQConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"zero_point={self.zero_point})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "awq"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# The AWQ kernel only supports Ampere or newer GPUs.
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return [
|
||||
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
|
||||
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
return cls(weight_bits, group_size, zero_point)
|
||||
|
||||
@classmethod
|
||||
def get_packed_tensor_names(cls) -> List[str]:
|
||||
return ["qweight", "qzeros"]
|
||||
|
||||
@classmethod
|
||||
def get_transposed_tensor_names(cls) -> List[str]:
|
||||
return ["qweight", "qzeros", "scales"]
|
||||
|
||||
@classmethod
|
||||
def get_tp_tensor_names(cls) -> List[str]:
|
||||
return ["qweight", "qzeros", "scales"]
|
75
vllm/model_executor/quantization_utils/base.py
Normal file
75
vllm/model_executor/quantization_utils/base.py
Normal file
@ -0,0 +1,75 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class QuantizationConfig:
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
"""Name of the quantization method."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
"""List of supported activation dtypes."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
"""Minimum GPU capability to support the quantization method.
|
||||
|
||||
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
|
||||
This requirement is due to the custom CUDA kernels used by the
|
||||
quantization method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
"""List of filenames to search for in the model directory."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
|
||||
"""Create a config class from the model's quantization config."""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
|
||||
"""Get a value from the model's quantization config."""
|
||||
for key in keys:
|
||||
if key in config:
|
||||
return config[key]
|
||||
raise ValueError(f"Cannot find any of {keys} in the model's "
|
||||
"quantization config.")
|
||||
|
||||
@classmethod
|
||||
def get_packed_tensor_names(cls) -> List[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def is_packed(cls, tensor_name: str) -> bool:
|
||||
"""Returns True if a tensor is packed.
|
||||
|
||||
A tensor is considered packed if each element in the tensor is a
|
||||
packed representation of multiple elements in the original tensor.
|
||||
For example, an INT32 element in the tensor may represent 8 INT4
|
||||
elements in the original tensor.
|
||||
"""
|
||||
return any(tag in tensor_name for tag in cls.get_packed_tensor_names())
|
||||
|
||||
@classmethod
|
||||
def get_transposed_tensor_names(cls) -> List[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def is_transposed(cls, tensor_name: str) -> bool:
|
||||
"""Returns True if a tensor is transposed relative to nn.Linear.weight.
|
||||
"""
|
||||
return any(tag in tensor_name
|
||||
for tag in cls.get_transposed_tensor_names())
|
||||
|
||||
@classmethod
|
||||
def get_tp_tensor_names(cls) -> List[str]:
|
||||
raise NotImplementedError
|
@ -4,7 +4,7 @@ import glob
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Iterator, List, Optional, Tuple, Any
|
||||
from typing import Any, Iterator, List, Optional, Tuple
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
@ -13,6 +13,8 @@ import torch
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.quantization_utils import get_quant_class
|
||||
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -44,7 +46,7 @@ def _shared_pointers(tensors):
|
||||
def convert_bin_to_safetensor_file(
|
||||
pt_filename: str,
|
||||
sf_filename: str,
|
||||
):
|
||||
) -> None:
|
||||
loaded = torch.load(pt_filename, map_location="cpu")
|
||||
if "state_dict" in loaded:
|
||||
loaded = loaded["state_dict"]
|
||||
@ -78,15 +80,55 @@ def convert_bin_to_safetensor_file(
|
||||
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||
|
||||
|
||||
# TODO(woosuk): Move this to other place.
|
||||
def get_quant_config(
|
||||
quantization: str,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
) -> QuantizationConfig:
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
if not is_local:
|
||||
# Download the config files.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
hf_folder = snapshot_download(model_name_or_path,
|
||||
allow_patterns="*.json",
|
||||
cache_dir=cache_dir,
|
||||
tqdm_class=Disabledtqdm)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
|
||||
|
||||
quant_cls = get_quant_class(quantization)
|
||||
quant_config_files = [
|
||||
f for f in config_files if any(
|
||||
f.endswith(x) for x in quant_cls.get_config_filenames())
|
||||
]
|
||||
if len(quant_config_files) == 0:
|
||||
raise ValueError(f"Cannot find the config file for {quantization}")
|
||||
if len(quant_config_files) > 1:
|
||||
raise ValueError(f"Found multiple config files for {quantization}: "
|
||||
f"{quant_config_files}")
|
||||
|
||||
quant_config_file = quant_config_files[0]
|
||||
with open(quant_config_file, "r") as f:
|
||||
config = json.load(f)
|
||||
return quant_cls.from_config(config)
|
||||
|
||||
|
||||
def prepare_hf_model_weights(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_safetensors: bool = False,
|
||||
fall_back_to_pt: bool = True,
|
||||
):
|
||||
revision: Optional[str] = None,
|
||||
) -> Tuple[str, List[str], bool]:
|
||||
# Download model weights from huggingface.
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
|
||||
if use_safetensors:
|
||||
allow_patterns = ["*.safetensors"]
|
||||
else:
|
||||
# Some quantized models use .pt files for storing the weights.
|
||||
allow_patterns = ["*.bin", "*.pt"]
|
||||
if not is_local:
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
@ -94,10 +136,13 @@ def prepare_hf_model_weights(
|
||||
hf_folder = snapshot_download(model_name_or_path,
|
||||
allow_patterns=allow_patterns,
|
||||
cache_dir=cache_dir,
|
||||
tqdm_class=Disabledtqdm)
|
||||
tqdm_class=Disabledtqdm,
|
||||
revision=revision)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
|
||||
hf_weights_files: List[str] = []
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||
if not use_safetensors:
|
||||
hf_weights_files = [
|
||||
x for x in hf_weights_files if not x.endswith("training_args.bin")
|
||||
@ -107,7 +152,8 @@ def prepare_hf_model_weights(
|
||||
return prepare_hf_model_weights(model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
use_safetensors=False,
|
||||
fall_back_to_pt=False)
|
||||
fall_back_to_pt=False,
|
||||
revision=revision)
|
||||
|
||||
if len(hf_weights_files) == 0:
|
||||
raise RuntimeError(
|
||||
@ -120,6 +166,7 @@ def hf_model_weights_iterator(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None,
|
||||
) -> Iterator[Tuple[str, torch.Tensor]]:
|
||||
use_safetensors = False
|
||||
use_np_cache = False
|
||||
@ -140,7 +187,8 @@ def hf_model_weights_iterator(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
use_safetensors=use_safetensors,
|
||||
fall_back_to_pt=fall_back_to_pt)
|
||||
fall_back_to_pt=fall_back_to_pt,
|
||||
revision=revision)
|
||||
|
||||
if use_np_cache:
|
||||
# Currently np_cache only support *.bin checkpoints
|
||||
|
@ -1,9 +1,17 @@
|
||||
"""Sampling parameters for text generation."""
|
||||
from enum import IntEnum
|
||||
from functools import cached_property
|
||||
from typing import List, Optional, Union
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class SamplingType(IntEnum):
|
||||
GREEDY = 0
|
||||
RANDOM = 1
|
||||
BEAM = 2
|
||||
|
||||
|
||||
class SamplingParams:
|
||||
"""Sampling parameters for text generation.
|
||||
|
||||
@ -45,10 +53,15 @@ class SamplingParams:
|
||||
(canonical beam search algorithm).
|
||||
stop: List of strings that stop the generation when they are generated.
|
||||
The returned output will not contain the stop strings.
|
||||
stop_token_ids: List of tokens that stop the generation when they are
|
||||
generated. The returned output will contain the stop tokens unless
|
||||
the stop tokens are sepcial tokens.
|
||||
ignore_eos: Whether to ignore the EOS token and continue generating
|
||||
tokens after the EOS token is generated.
|
||||
max_tokens: Maximum number of tokens to generate per output sequence.
|
||||
logprobs: Number of log probabilities to return per output token.
|
||||
skip_special_tokens: Whether to skip special tokens in the output.
|
||||
Defaults to true.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -64,9 +77,11 @@ class SamplingParams:
|
||||
length_penalty: float = 1.0,
|
||||
early_stopping: Union[bool, str] = False,
|
||||
stop: Union[None, str, List[str]] = None,
|
||||
stop_token_ids: List[int] = None,
|
||||
ignore_eos: bool = False,
|
||||
max_tokens: int = 16,
|
||||
logprobs: Optional[int] = None,
|
||||
skip_special_tokens: bool = True,
|
||||
) -> None:
|
||||
self.n = n
|
||||
self.best_of = best_of if best_of is not None else n
|
||||
@ -84,9 +99,14 @@ class SamplingParams:
|
||||
self.stop = [stop]
|
||||
else:
|
||||
self.stop = list(stop)
|
||||
if stop_token_ids is None:
|
||||
self.stop_token_ids = []
|
||||
else:
|
||||
self.stop_token_ids = list(stop_token_ids)
|
||||
self.ignore_eos = ignore_eos
|
||||
self.max_tokens = max_tokens
|
||||
self.logprobs = logprobs
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
|
||||
self._verify_args()
|
||||
if self.use_beam_search:
|
||||
@ -158,6 +178,14 @@ class SamplingParams:
|
||||
if self.top_k != -1:
|
||||
raise ValueError("top_k must be -1 when using greedy sampling.")
|
||||
|
||||
@cached_property
|
||||
def sampling_type(self) -> SamplingType:
|
||||
if self.use_beam_search:
|
||||
return SamplingType.BEAM
|
||||
if self.temperature < _SAMPLING_EPS:
|
||||
return SamplingType.GREEDY
|
||||
return SamplingType.RANDOM
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SamplingParams(n={self.n}, "
|
||||
f"best_of={self.best_of}, "
|
||||
@ -172,4 +200,5 @@ class SamplingParams:
|
||||
f"stop={self.stop}, "
|
||||
f"ignore_eos={self.ignore_eos}, "
|
||||
f"max_tokens={self.max_tokens}, "
|
||||
f"logprobs={self.logprobs})")
|
||||
f"logprobs={self.logprobs}, "
|
||||
f"skip_special_tokens={self.skip_special_tokens})")
|
||||
|
@ -114,7 +114,6 @@ class Sequence:
|
||||
|
||||
self.data = SequenceData(prompt_token_ids)
|
||||
self.output_logprobs: List[Dict[int, float]] = []
|
||||
self.output_tokens: List[str] = []
|
||||
self.output_text = ""
|
||||
|
||||
self.logical_token_blocks: List[LogicalTokenBlock] = []
|
||||
@ -122,6 +121,12 @@ class Sequence:
|
||||
self._append_tokens_to_blocks(prompt_token_ids)
|
||||
self.status = SequenceStatus.WAITING
|
||||
|
||||
# Used for incremental detokenization
|
||||
self.prefix_offset = 0
|
||||
self.read_offset = 0
|
||||
# Input + output tokens
|
||||
self.tokens: Optional[List[str]] = None
|
||||
|
||||
def _append_logical_block(self) -> None:
|
||||
block = LogicalTokenBlock(
|
||||
block_number=len(self.logical_token_blocks),
|
||||
@ -245,8 +250,8 @@ class SequenceGroup:
|
||||
# generation stage, we will have `best_of` sequences running.
|
||||
return self.sampling_params.best_of
|
||||
# At sampling stages, return the number of actual sequences
|
||||
# running.
|
||||
return self.num_seqs(status=SequenceStatus.RUNNING)
|
||||
# that are not finished yet.
|
||||
return self.num_unfinished_seqs()
|
||||
|
||||
def get_seqs(
|
||||
self,
|
||||
@ -259,12 +264,23 @@ class SequenceGroup:
|
||||
seq for seq in self.seqs_dict.values() if seq.status == status
|
||||
]
|
||||
|
||||
def get_unfinished_seqs(self) -> List[Sequence]:
|
||||
return [
|
||||
seq for seq in self.seqs_dict.values() if not seq.is_finished()
|
||||
]
|
||||
|
||||
def get_finished_seqs(self) -> List[Sequence]:
|
||||
return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
|
||||
|
||||
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
|
||||
return len(self.get_seqs(status))
|
||||
|
||||
def num_unfinished_seqs(self) -> int:
|
||||
return len(self.get_unfinished_seqs())
|
||||
|
||||
def num_finished_seqs(self) -> int:
|
||||
return len(self.get_finished_seqs())
|
||||
|
||||
def find(self, seq_id: int) -> Sequence:
|
||||
if seq_id not in self.seqs_dict:
|
||||
raise ValueError(f"Sequence {seq_id} not found.")
|
||||
@ -345,7 +361,7 @@ class SequenceOutputs:
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, SequenceOutputs):
|
||||
return NotImplementedError()
|
||||
raise NotImplementedError()
|
||||
return (self.parent_seq_id == other.parent_seq_id
|
||||
and self.output_token == other.output_token
|
||||
and self.logprobs == other.logprobs)
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
|
||||
from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import
|
||||
@ -12,10 +14,21 @@ _CONFIG_REGISTRY = {
|
||||
}
|
||||
|
||||
|
||||
def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
|
||||
def get_config(model: str,
|
||||
trust_remote_code: bool,
|
||||
revision: Optional[str] = None) -> PretrainedConfig:
|
||||
# NOTE: Because the Mistral model in HF hub does not have
|
||||
# `configuration_mistral.py`, we cannot use `AutoConfig` to load the
|
||||
# config. Instead, we use `MistralConfig` directly.
|
||||
# NOTE: This is a hack. This does not work for local models.
|
||||
# FIXME: Remove this once the Mistral model is available in the stable
|
||||
# version of HF transformers.
|
||||
if "mistral" in model.lower():
|
||||
return MistralConfig.from_pretrained(model, revision=revision)
|
||||
|
||||
try:
|
||||
config = AutoConfig.from_pretrained(
|
||||
model, trust_remote_code=trust_remote_code)
|
||||
model, trust_remote_code=trust_remote_code, revision=revision)
|
||||
except ValueError as e:
|
||||
if (not trust_remote_code and
|
||||
"requires you to execute the configuration file" in str(e)):
|
||||
@ -29,5 +42,5 @@ def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
|
||||
raise e
|
||||
if config.model_type in _CONFIG_REGISTRY:
|
||||
config_class = _CONFIG_REGISTRY[config.model_type]
|
||||
config = config_class.from_pretrained(model)
|
||||
config = config_class.from_pretrained(model, revision=revision)
|
||||
return config
|
||||
|
@ -6,6 +6,7 @@ from vllm.transformers_utils.configs.qwen import QWenConfig
|
||||
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
|
||||
# `FalconConfig` class from the official HuggingFace transformers library.
|
||||
from vllm.transformers_utils.configs.falcon import RWConfig
|
||||
from vllm.transformers_utils.configs.mistral import MistralConfig
|
||||
|
||||
__all__ = [
|
||||
"MPTConfig",
|
||||
@ -13,4 +14,5 @@ __all__ = [
|
||||
"AquilaConfig",
|
||||
"QWenConfig",
|
||||
"RWConfig",
|
||||
"MistralConfig",
|
||||
]
|
||||
|
66
vllm/transformers_utils/configs/mistral.py
Normal file
66
vllm/transformers_utils/configs/mistral.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Mistral-7B-v0.1 configuration"""
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class MistralConfig(PretrainedConfig):
|
||||
model_type = "mistral"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=4096,
|
||||
intermediate_size=14336,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=8,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=4096 * 32,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
sliding_window=4096,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
@ -7,65 +7,54 @@ from transformers import PretrainedConfig
|
||||
class QWenConfig(PretrainedConfig):
|
||||
model_type = "qwen"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {
|
||||
"hidden_size": "n_embd",
|
||||
"num_attention_heads": "n_head",
|
||||
"max_position_embeddings": "n_positions",
|
||||
"num_hidden_layers": "n_layer",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=151851,
|
||||
n_embd=4096,
|
||||
n_layer=32,
|
||||
n_head=32,
|
||||
n_inner=None,
|
||||
embd_pdrop=0.0,
|
||||
attn_pdrop=0.0,
|
||||
layer_norm_epsilon=1e-5,
|
||||
vocab_size=151936,
|
||||
hidden_size=4096,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
emb_dropout_prob=0.0,
|
||||
attn_dropout_prob=0.0,
|
||||
layer_norm_epsilon=1e-6,
|
||||
initializer_range=0.02,
|
||||
max_position_embeddings=8192,
|
||||
scale_attn_weights=True,
|
||||
use_cache=True,
|
||||
eos_token_id=151643,
|
||||
apply_residual_connection_post_layernorm=False,
|
||||
bf16=True,
|
||||
bf16=False,
|
||||
fp16=False,
|
||||
fp32=False,
|
||||
kv_channels=128,
|
||||
rotary_pct=1.0,
|
||||
rotary_emb_base=10000,
|
||||
use_dynamic_ntk=False,
|
||||
use_logn_attn=False,
|
||||
use_flash_attn=True,
|
||||
ffn_hidden_size=22016,
|
||||
use_dynamic_ntk=True,
|
||||
use_logn_attn=True,
|
||||
use_flash_attn="auto",
|
||||
intermediate_size=22016,
|
||||
no_bias=True,
|
||||
tie_word_embeddings=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.eos_token_id = eos_token_id
|
||||
super().__init__(eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_head = n_head
|
||||
self.n_inner = n_inner
|
||||
self.embd_pdrop = embd_pdrop
|
||||
self.attn_pdrop = attn_pdrop
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.emb_dropout_prob = emb_dropout_prob
|
||||
self.attn_dropout_prob = attn_dropout_prob
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_range = initializer_range
|
||||
self.scale_attn_weights = scale_attn_weights
|
||||
self.use_cache = use_cache
|
||||
self.apply_residual_connection_post_layernorm = (
|
||||
apply_residual_connection_post_layernorm)
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.bf16 = bf16
|
||||
self.fp16 = fp16
|
||||
self.fp32 = fp32
|
||||
self.kv_channels = kv_channels
|
||||
self.rotary_pct = rotary_pct
|
||||
self.rotary_emb_base = rotary_emb_base
|
||||
self.use_dynamic_ntk = use_dynamic_ntk
|
||||
self.use_logn_attn = use_logn_attn
|
||||
self.use_flash_attn = use_flash_attn
|
||||
self.ffn_hidden_size = ffn_hidden_size
|
||||
self.no_bias = no_bias
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import List, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast)
|
||||
@ -28,8 +28,8 @@ def get_tokenizer(
|
||||
if ("llama" in tokenizer_name.lower() and kwargs.get("use_fast", True)
|
||||
and tokenizer_name != _FAST_LLAMA_TOKENIZER):
|
||||
logger.info(
|
||||
"For some LLaMA-based models, initializing the fast tokenizer may "
|
||||
"take a long time. To eliminate the initialization time, consider "
|
||||
"For some LLaMA V1 models, initializing the fast tokenizer may "
|
||||
"take a long time. To reduce the initialization time, consider "
|
||||
f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
|
||||
"tokenizer.")
|
||||
try:
|
||||
@ -41,9 +41,9 @@ def get_tokenizer(
|
||||
except TypeError as e:
|
||||
# The LLaMA tokenizer causes a protobuf error in some environments.
|
||||
err_msg = (
|
||||
"Failed to load the tokenizer. If you are using a LLaMA-based "
|
||||
f"model, use '{_FAST_LLAMA_TOKENIZER}' instead of the original "
|
||||
"tokenizer.")
|
||||
"Failed to load the tokenizer. If you are using a LLaMA V1 model "
|
||||
f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
|
||||
"original tokenizer.")
|
||||
raise RuntimeError(err_msg) from e
|
||||
except ValueError as e:
|
||||
# If the error pertains to the tokenizer class not existing or not
|
||||
@ -67,33 +67,11 @@ def get_tokenizer(
|
||||
return tokenizer
|
||||
|
||||
|
||||
def detokenize_incrementally(
|
||||
def _convert_tokens_to_string_with_added_encoders(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
prev_output_tokens: List[str],
|
||||
new_token_id: int,
|
||||
output_tokens: List[str],
|
||||
skip_special_tokens: bool,
|
||||
) -> Tuple[str, str]:
|
||||
"""Detokenizes the new token in conjunction with the previous output tokens.
|
||||
|
||||
NOTE: This function does not update prev_output_tokens.
|
||||
|
||||
Returns:
|
||||
new_token: The new token as a string.
|
||||
output_text: The new output text as a string.
|
||||
"""
|
||||
if skip_special_tokens and (new_token_id in tokenizer.all_special_ids):
|
||||
return None, prev_output_tokens
|
||||
new_token = tokenizer.convert_ids_to_tokens(
|
||||
new_token_id, skip_special_tokens=skip_special_tokens)
|
||||
output_tokens = prev_output_tokens + [new_token]
|
||||
|
||||
# Convert the tokens to a string.
|
||||
# Optimization: If the tokenizer does not have `added_tokens_encoder`,
|
||||
# then we can directly use `convert_tokens_to_string`.
|
||||
if not getattr(tokenizer, "added_tokens_encoder", {}):
|
||||
output_text = tokenizer.convert_tokens_to_string(output_tokens)
|
||||
return new_token, output_text
|
||||
|
||||
) -> str:
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
|
||||
# NOTE(woosuk): The following code is slow because it runs a for loop over
|
||||
@ -115,5 +93,61 @@ def detokenize_incrementally(
|
||||
if current_sub_text:
|
||||
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
|
||||
sub_texts.append(sub_text)
|
||||
output_text = " ".join(sub_texts)
|
||||
return new_token, output_text
|
||||
return " ".join(sub_texts)
|
||||
|
||||
|
||||
# Based on
|
||||
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
|
||||
# under Apache 2.0 license
|
||||
def detokenize_incrementally(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
all_input_ids: List[int],
|
||||
prev_tokens: Optional[List[str]],
|
||||
prefix_offset: int = 0,
|
||||
read_offset: int = 0,
|
||||
skip_special_tokens: bool = False,
|
||||
) -> Tuple[List[str], str, int, int]:
|
||||
new_token_id = all_input_ids[-1]
|
||||
# This is the first iteration for this sequence
|
||||
if prev_tokens is None:
|
||||
new_tokens = tokenizer.convert_ids_to_tokens(
|
||||
all_input_ids, skip_special_tokens=skip_special_tokens)
|
||||
output_tokens = new_tokens
|
||||
# 5 is an arbitrary value that should work for all
|
||||
# tokenizers (bigger = more conservative).
|
||||
# Subtract 1 extra to account for the generated token.
|
||||
prefix_offset = max(len(output_tokens) - 6, 0)
|
||||
read_offset = max(len(output_tokens) - 1, 0)
|
||||
else:
|
||||
# Put new_token_id in a list so skip_special_tokens is respected
|
||||
new_tokens = tokenizer.convert_ids_to_tokens(
|
||||
[new_token_id], skip_special_tokens=skip_special_tokens)
|
||||
output_tokens = prev_tokens + new_tokens
|
||||
|
||||
# The prefix text is necessary only to defeat cleanup algorithms in
|
||||
# the decode which decide to add a space or not depending on the
|
||||
# surrounding ids.
|
||||
if not getattr(tokenizer, "added_tokens_encoder", {}):
|
||||
prefix_text = tokenizer.convert_tokens_to_string(
|
||||
output_tokens[prefix_offset:read_offset])
|
||||
new_text = tokenizer.convert_tokens_to_string(
|
||||
output_tokens[prefix_offset:])
|
||||
else:
|
||||
prefix_text = _convert_tokens_to_string_with_added_encoders(
|
||||
tokenizer,
|
||||
output_tokens[prefix_offset:read_offset],
|
||||
skip_special_tokens=skip_special_tokens)
|
||||
new_text = _convert_tokens_to_string_with_added_encoders(
|
||||
tokenizer,
|
||||
output_tokens[prefix_offset:],
|
||||
skip_special_tokens=skip_special_tokens)
|
||||
|
||||
if len(new_text) > len(prefix_text) and not new_text.endswith("<EFBFBD>"):
|
||||
# utf-8 char at the end means it's a potential unfinished byte sequence
|
||||
# from byte fallback tokenization.
|
||||
# If it's in the middle, it's probably a real invalid id generated
|
||||
# by the model
|
||||
new_text = new_text[len(prefix_text):]
|
||||
return new_tokens, new_text, read_offset, len(output_tokens)
|
||||
else:
|
||||
return new_tokens, "", prefix_offset, read_offset
|
||||
|
@ -1,10 +1,12 @@
|
||||
import enum
|
||||
from platform import uname
|
||||
import uuid
|
||||
from platform import uname
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from vllm import cuda_utils
|
||||
|
||||
|
||||
class Device(enum.Enum):
|
||||
GPU = enum.auto()
|
||||
@ -25,6 +27,15 @@ class Counter:
|
||||
self.counter = 0
|
||||
|
||||
|
||||
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
|
||||
"""Returns the maximum shared memory per thread block in bytes."""
|
||||
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 # pylint: disable=invalid-name
|
||||
max_shared_mem = cuda_utils.get_device_attribute(
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
|
||||
return int(max_shared_mem)
|
||||
|
||||
|
||||
def get_gpu_memory(gpu: int = 0) -> int:
|
||||
"""Returns the total memory of the GPU in bytes."""
|
||||
return torch.cuda.get_device_properties(gpu).total_memory
|
||||
|
@ -33,7 +33,7 @@ class CacheEngine:
|
||||
|
||||
self.head_size = model_config.get_head_size()
|
||||
self.num_layers = model_config.get_num_layers(parallel_config)
|
||||
self.num_heads = model_config.get_num_heads(parallel_config)
|
||||
self.num_heads = model_config.get_num_kv_heads(parallel_config)
|
||||
self.dtype = model_config.dtype
|
||||
|
||||
self.block_size = cache_config.block_size
|
||||
@ -146,7 +146,7 @@ class CacheEngine:
|
||||
parallel_config: ParallelConfig,
|
||||
) -> int:
|
||||
head_size = model_config.get_head_size()
|
||||
num_heads = model_config.get_num_heads(parallel_config)
|
||||
num_heads = model_config.get_num_kv_heads(parallel_config)
|
||||
num_layers = model_config.get_num_layers(parallel_config)
|
||||
|
||||
key_cache_block = block_size * num_heads * head_size
|
||||
|
@ -13,7 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.utils import get_gpu_memory
|
||||
from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
|
||||
|
||||
|
||||
class Worker:
|
||||
@ -42,6 +42,7 @@ class Worker:
|
||||
# self.init_cache_engine().
|
||||
self.cache_config = None
|
||||
self.block_size = None
|
||||
self.sliding_window = None
|
||||
self.cache_engine = None
|
||||
self.cache_events = None
|
||||
self.gpu_cache = None
|
||||
@ -136,6 +137,15 @@ class Worker:
|
||||
def init_cache_engine(self, cache_config: CacheConfig) -> None:
|
||||
self.cache_config = cache_config
|
||||
self.block_size = cache_config.block_size
|
||||
self.sliding_window = cache_config.sliding_window
|
||||
|
||||
if self.sliding_window is None:
|
||||
max_seq_len = self.scheduler_config.max_model_len
|
||||
else:
|
||||
max_seq_len = min(self.scheduler_config.max_model_len,
|
||||
self.sliding_window)
|
||||
_check_if_can_support_max_seq_len(max_seq_len, self.block_size)
|
||||
|
||||
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
|
||||
self.parallel_config)
|
||||
self.cache_events = self.cache_engine.events
|
||||
@ -207,10 +217,11 @@ class Worker:
|
||||
|
||||
context_len = seq_data.get_len()
|
||||
position = context_len - 1
|
||||
if self.sliding_window is not None:
|
||||
context_len = min(context_len, self.sliding_window)
|
||||
input_positions.append(position)
|
||||
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
generation_block_tables.append(block_table)
|
||||
|
||||
max_context_len = max(max_context_len, context_len)
|
||||
max_num_blocks_per_seq = max(max_num_blocks_per_seq,
|
||||
@ -222,21 +233,37 @@ class Worker:
|
||||
slot = block_number * self.block_size + block_offset
|
||||
slot_mapping.append(slot)
|
||||
|
||||
if self.sliding_window is not None:
|
||||
sliding_window_blocks = (self.sliding_window //
|
||||
self.block_size)
|
||||
block_table = block_table[-sliding_window_blocks:]
|
||||
generation_block_tables.append(block_table)
|
||||
|
||||
# Optimization: Pad the input length to be a multiple of 8.
|
||||
# This is required for utilizing the Tensor Cores in NVIDIA GPUs.
|
||||
input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
|
||||
input_positions = _pad_to_alignment(input_positions, multiple_of=8)
|
||||
|
||||
# Convert to tensors.
|
||||
tokens_tensor = torch.cuda.LongTensor(input_tokens)
|
||||
positions_tensor = torch.cuda.LongTensor(input_positions)
|
||||
slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping)
|
||||
context_lens_tensor = torch.cuda.IntTensor(context_lens)
|
||||
tokens_tensor = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
positions_tensor = torch.tensor(input_positions,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
slot_mapping_tensor = torch.tensor(slot_mapping,
|
||||
dtype=torch.int,
|
||||
device="cuda")
|
||||
context_lens_tensor = torch.tensor(context_lens,
|
||||
dtype=torch.int,
|
||||
device="cuda")
|
||||
padded_block_tables = [
|
||||
_pad_to_max(block_table, max_num_blocks_per_seq)
|
||||
for block_table in generation_block_tables
|
||||
]
|
||||
block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)
|
||||
block_tables_tensor = torch.tensor(padded_block_tables,
|
||||
dtype=torch.int,
|
||||
device="cuda")
|
||||
|
||||
seq_data: Dict[int, SequenceData] = {}
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
@ -250,6 +277,7 @@ class Worker:
|
||||
context_lens=context_lens_tensor,
|
||||
max_context_len=max_context_len,
|
||||
block_tables=block_tables_tensor,
|
||||
sliding_window=self.sliding_window,
|
||||
)
|
||||
return tokens_tensor, positions_tensor, input_metadata
|
||||
|
||||
@ -337,3 +365,23 @@ def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
|
||||
|
||||
def _pad_to_max(x: List[int], max_len: int) -> List[int]:
|
||||
return x + [0] * (max_len - len(x))
|
||||
|
||||
|
||||
def _check_if_can_support_max_seq_len(max_seq_len: int,
|
||||
block_size: int) -> None:
|
||||
# Follows the logic in
|
||||
# attention_kernels.cu::single_query_cached_kv_attention_launcher
|
||||
max_shared_mem = get_max_shared_memory_bytes()
|
||||
float32_bytes = torch.finfo(torch.float).bits // 8
|
||||
padded_max_seq_len = (
|
||||
(max_seq_len + block_size - 1) / block_size) * block_size
|
||||
# padded_max_seq_len + extra buffer
|
||||
required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
|
||||
if padded_max_seq_len * float32_bytes > max_shared_mem:
|
||||
raise RuntimeError(
|
||||
f"vLLM cannot currently support max_model_len={max_seq_len} "
|
||||
f"with block_size={block_size} on GPU with compute "
|
||||
f"capability {torch.cuda.get_device_capability()} "
|
||||
f"(required shared memory {required_shared_mem} > "
|
||||
f"available shared memory {max_shared_mem}). "
|
||||
"This will be fixed in a future release.")
|
||||
|
Reference in New Issue
Block a user