Compare commits

...

45 Commits

Author SHA1 Message Date
aa84c92ef6 Bump up version to 0.1.3 (#657) 2023-08-02 16:46:53 -07:00
f7389f4763 [Doc] Add Baichuan 13B to supported models (#656) 2023-08-02 16:45:12 -07:00
55fe8a81ec Refactor scheduler (#658) 2023-08-02 16:42:01 -07:00
e8ddc08ec8 [BUG FIX] upgrade fschat version to 0.2.23 (#650)
Co-authored-by: hao.yu <hao.yu@cn-c017.server.mila.quebec>
2023-08-02 14:05:59 -07:00
1b0bd0fe8a Add Falcon support (new) (#592) 2023-08-02 14:04:39 -07:00
20044cab7a Fix log message in scheduler (#652) 2023-08-02 13:35:10 -07:00
64f23c2900 fix baichuan for different position embedding for 7b and 13b models (#643) 2023-08-01 22:22:51 -07:00
d4c7755ca8 fix biachuan-7b tp (#598)
Co-authored-by: wq.chu <wq.chu@tianrang-inc.com>
2023-08-01 15:41:36 -07:00
aa39e42c5a fix doc (#622) 2023-07-31 13:11:57 -07:00
953f28cf9a fix ModuleNotFoundError (#599)
Co-authored-by: fangli <fangli@tencent.com>
2023-07-29 20:52:41 -07:00
c0d00f5be6 [Fix] fix import error of RayWorker (#604) (#605) 2023-07-27 23:37:40 -07:00
58a072be15 [Fix] Add model sequence length into model config (#575) 2023-07-25 23:46:30 -07:00
82ad323dee [Fix] Add chat completion Example and simplify dependencies (#576) 2023-07-25 23:45:48 -07:00
df5dd3c68e Add Baichuan-7B to README (#494) 2023-07-25 15:25:12 -07:00
2d867b55fa fixed tensor parallel is not defined (#564) 2023-07-25 14:16:51 -07:00
d7a1c6d614 Fix paged attention testing. (#495)
Signed-off-by: Tao Peng <jiankeng.pt@alibaba-inc.com>
2023-07-24 21:01:56 -07:00
7d5a155e4a [Fix] Fix GPTBigcoder for distributed execution (#503) 2023-07-24 18:36:33 -07:00
1dde34e0f8 GPTJConfig has no attribute rotary. (#532) 2023-07-24 11:29:30 -07:00
6fc2a38b11 Add support for LLaMA-2 (#505) 2023-07-20 11:38:27 -07:00
c487a221ee Fix bad assert in initialize_cluster if PG already exists (#526) 2023-07-19 23:17:12 -07:00
9925c17940 Ray placement group support (#397) 2023-07-19 22:49:31 -07:00
8c4b2592fb fix: enable trust-remote-code in api server & benchmark. (#509) 2023-07-19 17:06:15 -07:00
WRH
cf21a9bd5c support trust_remote_code in benchmark (#518) 2023-07-19 17:02:40 -07:00
16c3e295a8 fix(ray_utils): ignore re-init error (#465) 2023-07-19 17:01:19 -07:00
bda41c70dd hotfix attn alibi wo head mapping (#496)
Co-authored-by: oliveryuan <oliveryuan@basemind.com>
2023-07-18 11:31:48 -07:00
453bafb96f Merge pull request #498 from MoeedDar/main
Fixed old name reference for max_seq_len
2023-07-18 09:22:56 -07:00
328d231c17 Fixed old name reference for max_seq_len 2023-07-18 16:47:59 +01:00
b4b195b360 fix max seq len (#489) 2023-07-17 23:20:20 -07:00
20b0d88d16 Add support for baichuan (#365) 2023-07-17 13:50:55 -07:00
2bdea7ac11 [Fix] Fix the condition of max_seq_len (#477) 2023-07-17 00:33:48 -04:00
58df2883cb [Doc] Add doc for running vLLM on the cloud (#426)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2023-07-16 13:37:14 -07:00
6d7d95a70a Offload port selection to OS (#467) 2023-07-15 23:11:02 -07:00
96853af5a8 Optimize MQA Kernel (#452) 2023-07-14 20:06:40 -04:00
dbed69058c Fix the KeyError when loading bloom-based models (#441) 2023-07-13 21:58:09 -07:00
7b6ae94059 add vocab padding for LLama(Support WizardLM) (#411) 2023-07-13 23:56:22 -04:00
c6dfc3cdbe Fix handling of special tokens in decoding. (#418) 2023-07-12 11:14:56 -04:00
51be365143 fix: freeze pydantic to v1 (#429) 2023-07-12 11:10:55 -04:00
c894836108 [Model] Add support for GPT-J (#226)
Co-authored-by: woWoosuk Kwon <woosuk.kwon@berkeley.edu>
2023-07-08 17:55:16 -07:00
75beba29b5 Don't try to load training_args.bin (#373) 2023-07-08 15:26:28 -07:00
ddfdf470ae Add trust_remote_code arg to get_config (#405) 2023-07-08 15:24:17 -07:00
b6fbb9a565 Sort the outputs before return (#402) 2023-07-08 14:48:18 -07:00
2179e4f4c5 avoid python list copy in sequence initialization (#401) 2023-07-08 12:42:08 -07:00
a945fcc2ae Add trust-remote-code flag to handle remote tokenizers (#364) 2023-07-07 11:04:58 -07:00
be54f8e5c4 [Fix] Change /generate response-type to json for non-streaming (#374) 2023-07-06 18:15:17 -07:00
b396cb4998 fix: only response [DONE] once when streaming response. (#378) 2023-07-06 18:08:40 -07:00
48 changed files with 2300 additions and 574 deletions

View File

@ -17,6 +17,7 @@ Easy, fast, and cheap LLM serving for everyone
---
*Latest News* 🔥
- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
@ -41,11 +42,14 @@ vLLM is flexible and easy to use with:
vLLM seamlessly supports many Huggingface models, including the following architectures:
- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
- LLaMA (`lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)

View File

@ -21,6 +21,7 @@ def main(args: argparse.Namespace):
tensor_parallel_size=args.tensor_parallel_size,
max_num_seqs=args.batch_size,
max_num_batched_tokens=args.batch_size * args.input_len,
trust_remote_code=args.trust_remote_code,
)
sampling_params = SamplingParams(
@ -74,5 +75,7 @@ if __name__ == '__main__':
parser.add_argument('--use-beam-search', action='store_true')
parser.add_argument('--num-iters', type=int, default=3,
help='Number of iterations to run.')
parser.add_argument('--trust-remote-code', action='store_true',
help='trust remote code from huggingface')
args = parser.parse_args()
main(args)

View File

@ -177,7 +177,7 @@ def main(args: argparse.Namespace):
np.random.seed(args.seed)
api_url = f"http://{args.host}:{args.port}/generate"
tokenizer = get_tokenizer(args.tokenizer)
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
benchmark_start_time = time.time()
@ -227,5 +227,7 @@ if __name__ == "__main__":
"Otherwise, we use Poisson process to synthesize "
"the request arrival times.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument('--trust-remote-code', action='store_true',
help='trust remote code from huggingface')
args = parser.parse_args()
main(args)

View File

@ -67,12 +67,14 @@ def run_vllm(
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
) -> float:
llm = LLM(
model=model,
tokenizer=tokenizer,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
)
# Add the requests to the engine.
@ -106,9 +108,11 @@ def run_hf(
n: int,
use_beam_search: bool,
max_batch_size: int,
trust_remote_code: bool,
) -> float:
assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
llm = AutoModelForCausalLM.from_pretrained(model,
torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":
# To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token
@ -161,17 +165,18 @@ def main(args: argparse.Namespace):
random.seed(args.seed)
# Sample the requests.
tokenizer = get_tokenizer(args.tokenizer)
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
if args.backend == "vllm":
elapsed_time = run_vllm(
requests, args.model, args.tokenizer, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search)
args.seed, args.n, args.use_beam_search, args.trust_remote_code)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
args.use_beam_search, args.hf_max_batch_size)
elapsed_time = run_hf(
requests, args.model, tokenizer, args.n, args.use_beam_search,
args.hf_max_batch_size, args.trust_remote_code)
else:
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(
@ -199,6 +204,9 @@ if __name__ == "__main__":
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hf-max-batch-size", type=int, default=None,
help="Maximum batch size for HF backend.")
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
args = parser.parse_args()
if args.backend == "vllm":

View File

@ -6,6 +6,7 @@ void single_query_cached_kv_attention(
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& head_mapping,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,

View File

@ -74,14 +74,17 @@ template<
__global__ void single_query_cached_kv_attention_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int* __restrict__ head_mapping, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride) {
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
@ -91,6 +94,7 @@ __global__ void single_query_cached_kv_attention_kernel(
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const int kv_head_idx = head_mapping[head_idx];
const int seq_idx = blockIdx.y;
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
@ -158,8 +162,8 @@ __global__ void single_query_cached_kv_attention_kernel(
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
+ head_idx * HEAD_SIZE * BLOCK_SIZE
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride
+ physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
@ -246,8 +250,8 @@ __global__ void single_query_cached_kv_attention_kernel(
L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
+ head_idx * HEAD_SIZE * BLOCK_SIZE;
const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@ -328,12 +332,15 @@ __global__ void single_query_cached_kv_attention_kernel(
query_ptr, \
key_cache_ptr, \
value_cache_ptr, \
head_mapping_ptr, \
scale, \
block_tables_ptr, \
context_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
query_stride);
q_stride, \
kv_block_stride, \
kv_head_stride);
// TODO(woosuk): Tune NUM_THREADS.
template<
@ -345,6 +352,7 @@ void single_query_cached_kv_attention_launcher(
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& head_mapping,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
@ -354,7 +362,9 @@ void single_query_cached_kv_attention_launcher(
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int query_stride = query.stride(0);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
@ -368,6 +378,7 @@ void single_query_cached_kv_attention_launcher(
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
@ -382,7 +393,7 @@ void single_query_cached_kv_attention_launcher(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we omitted head sizes
// 32, 160, 192, 256.
// 32, 160, 192.
// case 32:
// LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
// break;
@ -407,9 +418,9 @@ void single_query_cached_kv_attention_launcher(
// case 192:
// LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
// break;
// case 256:
// LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
// break;
case 256:
LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
@ -422,6 +433,7 @@ void single_query_cached_kv_attention_launcher(
query, \
key_cache, \
value_cache, \
head_mapping, \
scale, \
block_tables, \
context_lens, \
@ -469,6 +481,7 @@ void single_query_cached_kv_attention(
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& head_mapping, // [num_heads]
float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]

View File

@ -7,11 +7,13 @@ template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
const int64_t* __restrict__ positions, // [num_tokens]
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int stride,
const int query_stride,
const int key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
@ -19,17 +21,17 @@ __global__ void rotary_embedding_neox_kernel(
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const int embed_dim = rot_dim / 2;
const int n = num_heads * embed_dim;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int token_head = token_idx * stride + head_idx * head_size;
const int token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const int out_x = token_idx * stride + head_idx * head_size + x_index;
const int out_y = token_idx * stride + head_idx * head_size + y_index;
const int out_x = token_idx * query_stride + head_idx * head_size + x_index;
const int out_y = token_idx * query_stride + head_idx * head_size + y_index;
const scalar_t cos = __ldg(cache_ptr + x_index);
const scalar_t sin = __ldg(cache_ptr + y_index);
@ -38,6 +40,22 @@ __global__ void rotary_embedding_neox_kernel(
const scalar_t q_y = query[token_head + y_index];
query[out_x] = q_x * cos - q_y * sin;
query[out_y] = q_y * cos + q_x * sin;
}
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const int out_x = token_idx * key_stride + head_idx * head_size + x_index;
const int out_y = token_idx * key_stride + head_idx * head_size + y_index;
const scalar_t cos = __ldg(cache_ptr + x_index);
const scalar_t sin = __ldg(cache_ptr + y_index);
const scalar_t k_x = key[token_head + x_index];
const scalar_t k_y = key[token_head + y_index];
@ -51,15 +69,16 @@ __global__ void rotary_embedding_neox_kernel(
void rotary_embedding_neox(
torch::Tensor& positions, // [num_tokens]
torch::Tensor& query, // [num_tokens, num_heads * head_size]
torch::Tensor& key, // [num_tokens, num_heads * head_size]
torch::Tensor& key, // [num_tokens, num_kv_heads * head_size]
int head_size,
torch::Tensor& cos_sin_cache) // [max_position, rot_dim]
{
int num_tokens = query.size(0);
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(1) / head_size;
int stride = query.stride(0);
TORCH_CHECK(stride == key.stride(0));
int num_kv_heads = key.size(1) / head_size;
int query_stride = query.stride(0);
int key_stride = key.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
@ -76,8 +95,10 @@ void rotary_embedding_neox(
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
stride,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
});
}

View File

@ -62,6 +62,7 @@ Documentation
:caption: Serving
serving/distributed_serving
serving/run_on_sky
.. toctree::
:maxdepth: 1

View File

@ -14,21 +14,30 @@ Alongside each architecture, we include some popular models that use it.
* - Architecture
- Models
- Example HuggingFace Models
* - :code:`BaiChuanForCausalLM`
- Baichuan
- :code:`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.
* - :code:`BloomForCausalLM`
- BLOOM, BLOOMZ, BLOOMChat
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
* - :code:`FalconForCausalLM`
- Falcon
- :code:`tiiuae/falcon-7b``, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
* - :code:`GPT2LMHeadModel`
- GPT-2
- :code:`gpt2`, :code:`gpt2-xl`, etc.
* - :code:`GPTBigCodeForCausalLM`
- StarCoder, SantaCoder, WizardCoder
- :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc.
* - :code:`GPTJForCausalLM`
- GPT-J
- :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc.
* - :code:`GPTNeoXForCausalLM`
- GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
- :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
* - :code:`LlamaForCausalLM`
- LLaMA, Vicuna, Alpaca, Koala, Guanaco
- :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc.
- LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
- :code:`meta-llama/Llama-2-13b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc.
* - :code:`MPTForCausalLM`
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.

View File

@ -0,0 +1,69 @@
.. _on_cloud:
Running on clouds with SkyPilot
===============================
.. raw:: html
<p align="center">
<img src="https://imgur.com/yxtzPEu.png" alt="vLLM"/>
</p>
vLLM can be run on the cloud to scale to multiple GPUs with `SkyPilot <https://github.com/skypilot-org/skypilot>`__, an open-source framework for running LLMs on any cloud.
To install SkyPilot and setup your cloud credentials, run:
.. code-block:: console
$ pip install skypilot
$ sky check
See the vLLM SkyPilot YAML for serving, `serving.yaml <https://github.com/skypilot-org/skypilot/blob/master/llm/vllm/serve.yaml>`__.
.. code-block:: yaml
resources:
accelerators: A100
envs:
MODEL_NAME: decapoda-research/llama-13b-hf
TOKENIZER: hf-internal-testing/llama-tokenizer
setup: |
conda create -n vllm python=3.9 -y
conda activate vllm
git clone https://github.com/vllm-project/vllm.git
cd vllm
pip install .
pip install gradio
run: |
conda activate vllm
echo 'Starting vllm api server...'
python -u -m vllm.entrypoints.api_server \
--model $MODEL_NAME \
--tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
--tokenizer $TOKENIZER 2>&1 | tee api_server.log &
echo 'Waiting for vllm api server to start...'
while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done
echo 'Starting gradio server...'
python vllm/examples/gradio_webserver.py
Start the serving the LLaMA-13B model on an A100 GPU:
.. code-block:: console
$ sky launch serving.yaml
Check the output of the command. There will be a sharable gradio link (like the last line of the following). Open it in your browser to use the LLaMA model to do the text completion.
.. code-block:: console
(task, pid=7431) Running on public URL: https://<gradio-hash>.gradio.live
**Optional**: Serve the 65B model instead of the default 13B and use more GPU:
.. code-block:: console
sky launch -c vllm-serve-new -s serve.yaml --gpus A100:8 --env MODEL_NAME=decapoda-research/llama-65b-hf

View File

@ -10,7 +10,8 @@ def main(args: argparse.Namespace):
# Test the following prompts.
test_prompts = [
("A robot may not injure a human being", SamplingParams()),
("A robot may not injure a human being",
SamplingParams(temperature=0.0)),
("To be or not to be,",
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
("What is the meaning of life?",
@ -27,7 +28,7 @@ def main(args: argparse.Namespace):
# Run the engine by calling `engine.step()` manually.
request_id = 0
while True:
# To test iteration-level scheduling, we add one request at each step.
# To test continuous batching, we add one request at each step.
if test_prompts:
prompt, sampling_params = test_prompts.pop(0)
engine.add_request(str(request_id), prompt, sampling_params)

View File

@ -0,0 +1,33 @@
import openai
# Modify OpenAI's API key and API base to use vLLM's API server.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
# List models API
models = openai.Model.list()
print("Models:", models)
model = models["data"][0]["id"]
# Chat completion API
chat_completion = openai.ChatCompletion.create(
model=model,
messages=[{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}])
print("Chat completion results:")
print(chat_completion)

View File

@ -3,26 +3,26 @@ import openai
# Modify OpenAI's API key and API base to use vLLM's API server.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
model = "facebook/opt-125m"
# Test list models API
# List models API
models = openai.Model.list()
print("Models:", models)
# Test completion API
stream = True
model = models["data"][0]["id"]
# Completion API
stream = False
completion = openai.Completion.create(
model=model,
prompt="A robot may not injure a human being",
echo=False,
n=2,
best_of=3,
stream=stream,
logprobs=3)
# print the completion
print("Completion results:")
if stream:
for c in completion:
print(c)
else:
print("Completion result:", completion)
print(completion)

View File

@ -1,12 +1,11 @@
ninja # For faster builds.
psutil
ray
ray >= 2.5.1
sentencepiece # Required for LLaMA tokenizer.
numpy
torch >= 2.0.0
transformers >= 4.28.0 # Required for LLaMA.
transformers >= 4.31.0 # Required for LLaMA-2.
xformers >= 0.0.19
fastapi
uvicorn
pydantic # Required for OpenAI server.
fschat # Required for OpenAI ChatCompletion Endpoint.
pydantic < 2 # Required for OpenAI server.

View File

@ -164,6 +164,7 @@ def run_single_query_cached_kv_attention(
block_size: int,
num_blocks: int,
dtype: torch.dtype,
num_kv_heads: int = None,
) -> None:
qkv = torch.empty(num_tokens,
3,
@ -199,8 +200,17 @@ def run_single_query_cached_kv_attention(
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda")
scale = float(1.0 / (head_size**0.5))
num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
assert num_heads % num_kv_heads == 0
num_queries_per_kv = num_heads // num_kv_heads
head_mapping = torch.repeat_interleave(
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
num_queries_per_kv)
output = torch.empty(num_tokens,
num_heads,
head_size,
@ -211,6 +221,7 @@ def run_single_query_cached_kv_attention(
query,
key_cache,
value_cache,
head_mapping,
scale,
block_tables,
context_lens,
@ -286,7 +297,7 @@ def test_single_query_cached_kv_attention() -> None:
torch.cuda.manual_seed(TEST_SEED)
for dtype in [torch.half, torch.bfloat16, torch.float]:
for block_size in [8, 16, 32]:
for head_size in [64, 80, 96, 128]:
for head_size in [64, 80, 96, 112, 128, 256]:
print(f'Testing single_query_cached_kv_attention with '
f'dtype={dtype}, block_size={block_size}, '
f'head_size={head_size}')
@ -304,7 +315,7 @@ def test_multi_query_kv_attention() -> None:
torch.random.manual_seed(TEST_SEED)
torch.cuda.manual_seed(TEST_SEED)
for dtype in [torch.half, torch.bfloat16, torch.float]:
for head_size in [64, 80, 96, 128]:
for head_size in [64, 80, 96, 112, 128, 256]:
print(f'Testing multi_query_kv_attention with dtype={dtype}, '
f'head_size={head_size}')
run_multi_query_kv_attention(

View File

@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
__version__ = "0.1.2"
__version__ = "0.1.3"
__all__ = [
"LLM",

View File

@ -20,6 +20,8 @@ class ModelConfig:
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
use_np_weights: Save a numpy copy of model weights for faster loading.
@ -36,6 +38,7 @@ class ModelConfig:
model: str,
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
download_dir: Optional[str],
use_np_weights: bool,
use_dummy_weights: bool,
@ -45,12 +48,13 @@ class ModelConfig:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
self.download_dir = download_dir
self.use_np_weights = use_np_weights
self.use_dummy_weights = use_dummy_weights
self.seed = seed
self.hf_config = get_config(model)
self.hf_config = get_config(model, trust_remote_code)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self._verify_tokenizer_mode()
@ -90,9 +94,46 @@ class ModelConfig:
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
# For GPTBigCode & Falcon:
# Note: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of
# KV heads.
if (getattr(self.hf_config, "multi_query", False) and
(self.hf_config.model_type == "falcon" and
not getattr(self.hf_config, "new_decoder_architecture", False))):
# Multi-query attention, only one KV head.
return 1
# For Falcon:
if getattr(self.hf_config, "n_head_kv", None) is not None:
return (self.hf_config.n_head_kv //
parallel_config.tensor_parallel_size)
# For LLaMA-2:
if getattr(self.hf_config, "num_key_value_heads", None) is not None:
return (self.hf_config.num_key_value_heads //
parallel_config.tensor_parallel_size)
total_num_attention_heads = self.hf_config.num_attention_heads
return total_num_attention_heads // parallel_config.tensor_parallel_size
def get_max_model_len(self) -> int:
max_model_len = float("inf")
possible_keys = [
# OPT
"max_position_embeddings",
# GPT-2
"n_positions",
# MPT
"max_seq_len",
# Others
"max_sequence_length",
"max_seq_length",
"seq_len",
]
for key in possible_keys:
max_len_key = getattr(self.hf_config, key, None)
if max_len_key is not None:
max_model_len = min(max_model_len, max_len_key)
return max_model_len
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_config.num_hidden_layers
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
@ -188,15 +229,15 @@ class SchedulerConfig:
a single iteration.
max_num_seqs: Maximum number of sequences to be processed in a single
iteration.
max_seq_len: Maximum length of a sequence (including prompt
max_model_len: Maximum length of a sequence (including prompt
and generated text).
"""
def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
max_seq_len: int) -> None:
max_model_len: int) -> None:
self.max_num_batched_tokens = max_num_batched_tokens
self.max_num_seqs = max_num_seqs
self.max_seq_len = max_seq_len
self.max_model_len = max_model_len
_STR_DTYPE_TO_TORCH_DTYPE = {

View File

@ -12,8 +12,6 @@ from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
logger = init_logger(__name__)
_LOGGING_INTERVAL_SEC = 5
class PreemptionMode(enum.Enum):
"""Preemption modes.
@ -32,19 +30,28 @@ class SchedulerOutputs:
def __init__(
self,
scheduled_seq_groups: List[SequenceGroup],
prompt_run: bool,
num_batched_tokens: int,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
ignored_seq_groups: List[SequenceGroup],
) -> None:
self.scheduled_seq_groups = scheduled_seq_groups
self.prompt_run = prompt_run
self.num_batched_tokens = num_batched_tokens
self.blocks_to_swap_in = blocks_to_swap_in
self.blocks_to_swap_out = blocks_to_swap_out
self.blocks_to_copy = blocks_to_copy
# Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out)
self.ignored_seq_groups = ignored_seq_groups
def is_empty(self) -> bool:
return (not self.blocks_to_swap_in and not self.blocks_to_swap_out
and not self.blocks_to_copy)
# NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
and not self.blocks_to_swap_out and not self.blocks_to_copy)
class Scheduler:
@ -53,11 +60,9 @@ class Scheduler:
self,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
log_stats: bool,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.log_stats = log_stats
# Instantiate the scheduling policy.
self.policy = PolicyFactory.get_policy(policy_name="fcfs")
@ -75,10 +80,6 @@ class Scheduler:
# Sequence groups in the SWAPPED state.
self.swapped: List[SequenceGroup] = []
self.last_logging_time: float = 0.0
# List[timestamp, num_tokens]
self.num_input_tokens: List[Tuple[float, int]] = []
def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue.
self.waiting.append(seq_group)
@ -101,21 +102,80 @@ class Scheduler:
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)
def _schedule(
self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
def _schedule(self) -> SchedulerOutputs:
# Blocks that need to be swaped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}
ignored_seq_groups: List[SequenceGroup] = []
# Fix the current time.
now = time.time()
# NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
# in order to minimize the preemption overheads.
# Preemption happens only when there is no available slot to keep all
# the sequence groups in the RUNNING state.
# Join waiting sequences if possible.
if not self.swapped:
ignored_seq_groups: List[SequenceGroup] = []
scheduled: List[SequenceGroup] = []
num_batched_tokens = 0
# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
# are added to the back.
while self.waiting:
seq_group = self.waiting[0]
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
prompt_limit = min(
self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens)
if num_prompt_tokens > prompt_limit:
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds limit of {prompt_limit}")
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
self.waiting.pop(0)
break
# If the sequence group cannot be allocated, stop.
if not self.block_manager.can_allocate(seq_group):
break
# If the number of batched tokens exceeds the limit, stop.
if (num_batched_tokens + num_prompt_tokens >
self.scheduler_config.max_num_batched_tokens):
break
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.num_seqs(
status=SequenceStatus.WAITING)
num_curr_seqs = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)
if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs):
break
seq_group = self.waiting.pop(0)
self._allocate(seq_group)
self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens
scheduled.append(seq_group)
if scheduled:
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled,
prompt_run=True,
num_batched_tokens=num_batched_tokens,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
ignored_seq_groups=ignored_seq_groups,
)
return scheduler_outputs
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# In this case, the policy is responsible for deciding which sequence
# groups to preempt.
self.running = self.policy.sort_by_priority(now, self.running)
@ -173,122 +233,26 @@ class Scheduler:
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)
# Join waiting sequences if possible.
prompt_group_ids: List[str] = []
# NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
# prioritized over the sequence groups in the WAITING state.
# This is because we want to bound the amount of CPU memory taken by
# the swapped sequence groups.
if not self.swapped:
# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
# are added to the back.
while self.waiting:
seq_group = self.waiting[0]
# If the sequence group has been preempted in this step, stop.
if seq_group in preempted:
break
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
if num_prompt_tokens >= self.scheduler_config.max_seq_len:
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
" and exceeds limit of "
f"{self.scheduler_config.max_seq_len}")
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
self.waiting.pop(0)
break
# If the sequence group cannot be allocated, stop.
if not self.block_manager.can_allocate(seq_group):
break
# If the number of batched tokens exceeds the limit, stop.
if (num_batched_tokens + num_prompt_tokens >
self.scheduler_config.max_num_batched_tokens):
break
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.num_seqs(
status=SequenceStatus.WAITING)
num_curr_seqs = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)
if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs):
break
seq_group = self.waiting.pop(0)
self._allocate(seq_group)
self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens
prompt_group_ids.append(seq_group.request_id)
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=self.running,
prompt_run=False,
num_batched_tokens=num_batched_tokens,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
ignored_seq_groups=[],
)
if not self.log_stats:
return scheduler_outputs, prompt_group_ids, ignored_seq_groups
return scheduler_outputs
# TODO(woosuk): Move the below code to the engine.
now = time.time()
if num_batched_tokens > 0:
self.num_input_tokens.append((now, num_batched_tokens))
elapsed_time = now - self.last_logging_time
if elapsed_time > _LOGGING_INTERVAL_SEC:
self.last_logging_time = now
self.num_input_tokens = [(t, n) for t, n in self.num_input_tokens
if now - t < _LOGGING_INTERVAL_SEC]
if len(self.num_input_tokens) > 1:
total_num_tokens = sum(n
for _, n in self.num_input_tokens[:-1])
window = now - self.num_input_tokens[0][0]
avg_throughput = total_num_tokens / window
else:
avg_throughput = 0.0
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
if total_num_cpu_blocks > 0:
num_free_cpu_blocks = (
self.block_manager.get_num_free_cpu_blocks())
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
else:
cpu_cache_usage = 0.0
logger.info(f"Throughput: {avg_throughput:.1f} tokens/s, "
f"Running: {len(self.running)} reqs, "
f"Swapped: {len(self.swapped)} reqs, "
f"Pending: {len(self.waiting)} reqs, "
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
return scheduler_outputs, prompt_group_ids, ignored_seq_groups
def schedule(
self
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
List[SequenceGroup]]:
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting.
(scheduler_outputs, prompt_group_ids,
ignored_seq_groups) = self._schedule()
scheduler_outputs = self._schedule()
# Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = []
for seq_group in self.running:
is_prompt = seq_group.request_id in prompt_group_ids
for seq_group in scheduler_outputs.scheduled_seq_groups:
seq_data: Dict[int, List[SequenceData]] = {}
block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@ -298,20 +262,27 @@ class Scheduler:
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=is_prompt,
is_prompt=scheduler_outputs.prompt_run,
seq_data=seq_data,
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
)
seq_group_metadata_list.append(seq_group_metadata)
return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups
return seq_group_metadata_list, scheduler_outputs
def update(
self,
seq_outputs: Dict[int, SequenceOutputs],
) -> List[SequenceGroup]:
# Update the running sequences and free blocks.
scheduled: List[SequenceGroup] = []
for seq_group in self.running:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
if seq.seq_id in seq_outputs:
scheduled.append(seq_group)
break
# Update the scheduled sequences and free blocks.
for seq_group in scheduled:
# Process beam search results before processing the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
output = seq_outputs[seq.seq_id]
@ -329,9 +300,7 @@ class Scheduler:
# Append a new token to the sequence.
output = seq_outputs[seq.seq_id]
seq.append_token_id(output.output_token, output.logprobs)
# Return a shallow copy of the running queue to prevent the queue
# from being modified by the caller.
return self.running.copy()
return scheduled
def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
seq.status = finish_status

View File

@ -13,6 +13,7 @@ class EngineArgs:
model: str
tokenizer: Optional[str] = None
tokenizer_mode: str = 'auto'
trust_remote_code: bool = False
download_dir: Optional[str] = None
use_np_weights: bool = False
use_dummy_weights: bool = False
@ -55,6 +56,9 @@ class EngineArgs:
help='tokenizer mode. "auto" will use the fast '
'tokenizer if available, and "slow" will '
'always use the slow tokenizer.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument('--download-dir',
type=str,
default=EngineArgs.download_dir,
@ -141,20 +145,19 @@ class EngineArgs:
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
# Initialize the configs.
model_config = ModelConfig(self.model, self.tokenizer,
self.tokenizer_mode, self.download_dir,
self.use_np_weights, self.use_dummy_weights,
self.dtype, self.seed)
self.tokenizer_mode, self.trust_remote_code,
self.download_dir, self.use_np_weights,
self.use_dummy_weights, self.dtype,
self.seed)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space)
parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size,
self.worker_use_ray)
model_max_len = getattr(model_config.hf_config,
'max_position_embeddings', float('inf'))
max_seq_len = min(self.max_num_batched_tokens, model_max_len)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs, max_seq_len)
self.max_num_seqs,
model_config.get_max_model_len())
return model_config, cache_config, parallel_config, scheduler_config

View File

@ -226,14 +226,14 @@ class AsyncLLMEngine:
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
# Initialize the cluster.
distributed_init_method, devices = initialize_cluster(
distributed_init_method, placement_group = initialize_cluster(
parallel_config, engine_args.engine_use_ray)
# Create the async LLM engine.
engine = cls(engine_args.worker_use_ray,
engine_args.engine_use_ray,
*engine_configs,
distributed_init_method,
devices,
placement_group,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats)
return engine

View File

@ -1,11 +1,13 @@
import time
from typing import Any, List, Optional
import copy
from functools import partial
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray
from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
@ -13,10 +15,18 @@ from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer)
from vllm.utils import Counter
from vllm.worker.worker import Worker
if ray:
from ray.air.util.torch_dist import init_torch_dist_process_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
_LOGGING_INTERVAL_SEC = 5
class LLMEngine:
"""An LLM engine that receives requests and generates texts.
@ -54,7 +64,7 @@ class LLMEngine:
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
distributed_init_method: str,
stage_devices: List[List[DeviceID]],
placement_group: Optional["PlacementGroup"],
log_stats: bool,
) -> None:
logger.info(
@ -62,6 +72,7 @@ class LLMEngine:
f"model={model_config.model!r}, "
f"tokenizer={model_config.tokenizer!r}, "
f"tokenizer_mode={model_config.tokenizer_mode}, "
f"trust_remote_code={model_config.trust_remote_code}, "
f"dtype={model_config.dtype}, "
f"use_dummy_weights={model_config.use_dummy_weights}, "
f"download_dir={model_config.download_dir!r}, "
@ -78,34 +89,88 @@ class LLMEngine:
self._verify_args()
self.tokenizer = get_tokenizer(
model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code)
self.seq_counter = Counter()
# Create the parallel GPU workers.
self.workers: List[Worker] = []
assert len(stage_devices) == 1, "Only support one stage for now."
for rank, node_resource, _ in stage_devices[0]:
worker_cls = Worker
if self.parallel_config.worker_use_ray:
worker_cls = ray.remote(
num_cpus=0,
num_gpus=1,
resources={node_resource: 1e-3},
)(worker_cls).remote
if self.parallel_config.worker_use_ray:
self._init_workers_ray(placement_group)
else:
self._init_workers(distributed_init_method)
worker = worker_cls(
model_config,
parallel_config,
scheduler_config,
rank,
distributed_init_method,
)
self.workers.append(worker)
# Profile the memory usage and initialize the cache.
self._init_cache()
# Create the scheduler.
self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)
self.scheduler = Scheduler(scheduler_config, cache_config)
# Logging.
self.last_logging_time = 0.0
# List of (timestamp, num_tokens)
self.num_prompt_tokens: List[Tuple[float, int]] = []
# List of (timestamp, num_tokens)
self.num_generation_tokens: List[Tuple[float, int]] = []
def _init_workers(self, distributed_init_method: str):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
assert self.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
self.workers: List[Worker] = []
worker = Worker(
self.model_config,
self.parallel_config,
self.scheduler_config,
0,
distributed_init_method,
)
self.workers.append(worker)
self._run_workers(
"init_model",
get_all_outputs=True,
)
def _init_workers_ray(self, placement_group: "PlacementGroup"):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
self.workers: List[Worker] = []
for bundle in placement_group.bundle_specs:
if not bundle.get("GPU", 0):
continue
worker = ray.remote(
num_cpus=0,
num_gpus=1,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True),
)(RayWorker).remote()
self.workers.append(worker)
# Initialize torch distributed process group for the workers.
init_torch_dist_process_group(self.workers, backend="nccl")
model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config)
self._run_workers("init_worker",
get_all_outputs=True,
worker_init_fn=lambda: Worker(
model_config,
parallel_config,
scheduler_config,
None,
None,
))
self._run_workers(
"init_model",
get_all_outputs=True,
)
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
@ -149,11 +214,12 @@ class LLMEngine:
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
# Initialize the cluster.
distributed_init_method, devices = initialize_cluster(parallel_config)
distributed_init_method, placement_group = initialize_cluster(
parallel_config)
# Create the LLM engine.
engine = cls(*engine_configs,
distributed_init_method,
devices,
placement_group,
log_stats=not engine_args.disable_log_stats)
return engine
@ -231,12 +297,17 @@ class LLMEngine:
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
(seq_group_metadata_list, scheduler_outputs,
ignored_seq_groups) = self.scheduler.schedule()
if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
and (not ignored_seq_groups)):
# Nothing to do.
return []
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if scheduler_outputs.is_empty():
if not scheduler_outputs.ignored_seq_groups:
# Nothing to do.
return []
# If there are ignored seq groups, we need to return them as the
# request outputs.
return [
RequestOutput.from_seq_group(seq_group)
for seq_group in scheduler_outputs.ignored_seq_groups
]
# Execute the model.
output = self._run_workers(
@ -258,11 +329,79 @@ class LLMEngine:
# Create the outputs.
request_outputs: List[RequestOutput] = []
for seq_group in seq_groups + ignored_seq_groups:
for seq_group in seq_groups + scheduler_outputs.ignored_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
if self.log_stats:
# Log the system stats.
self._log_system_stats(scheduler_outputs.prompt_run,
scheduler_outputs.num_batched_tokens)
return request_outputs
def _log_system_stats(
self,
prompt_run: bool,
num_batched_tokens: int,
) -> None:
now = time.time()
# Log the number of batched input tokens.
if prompt_run:
self.num_prompt_tokens.append((now, num_batched_tokens))
else:
self.num_generation_tokens.append((now, num_batched_tokens))
elapsed_time = now - self.last_logging_time
if elapsed_time < _LOGGING_INTERVAL_SEC:
return
# Discard the old stats.
self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens
if now - t < _LOGGING_INTERVAL_SEC]
self.num_generation_tokens = [(t, n)
for t, n in self.num_generation_tokens
if now - t < _LOGGING_INTERVAL_SEC]
if len(self.num_prompt_tokens) > 1:
total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
window = now - self.num_prompt_tokens[0][0]
avg_prompt_throughput = total_num_tokens / window
else:
avg_prompt_throughput = 0.0
if len(self.num_generation_tokens) > 1:
total_num_tokens = sum(n
for _, n in self.num_generation_tokens[:-1])
window = now - self.num_generation_tokens[0][0]
avg_generation_throughput = total_num_tokens / window
else:
avg_generation_throughput = 0.0
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
num_free_gpu_blocks = (
self.scheduler.block_manager.get_num_free_gpu_blocks())
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
if total_num_cpu_blocks > 0:
num_free_cpu_blocks = (
self.scheduler.block_manager.get_num_free_cpu_blocks())
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
else:
cpu_cache_usage = 0.0
logger.info("Avg prompt throughput: "
f"{avg_prompt_throughput:.1f} tokens/s, "
"Avg generation throughput: "
f"{avg_generation_throughput:.1f} tokens/s, "
f"Running: {len(self.scheduler.running)} reqs, "
f"Swapped: {len(self.scheduler.swapped)} reqs, "
f"Pending: {len(self.scheduler.waiting)} reqs, "
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
self.last_logging_time = now
def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
"""Decodes the sequence outputs."""
for seq_group in seq_groups:
@ -273,8 +412,9 @@ class LLMEngine:
seq.get_last_token_id(),
skip_special_tokens=True,
)
seq.output_tokens.append(new_token)
seq.output_text = new_output_text
if new_token is not None:
seq.output_tokens.append(new_token)
seq.output_text = new_output_text
def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
"""Stop the finished sequences."""
@ -295,9 +435,8 @@ class LLMEngine:
if stopped:
continue
# Check if the sequence has reached max_seq_len.
if (seq.get_len() >=
self.scheduler.scheduler_config.max_seq_len):
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
continue
@ -323,9 +462,10 @@ class LLMEngine:
"""Runs the given method on all workers."""
all_outputs = []
for worker in self.workers:
executor = getattr(worker, method)
if self.parallel_config.worker_use_ray:
executor = executor.remote
executor = partial(worker.execute_method.remote, method)
else:
executor = getattr(worker, method)
output = executor(*args, **kwargs)
all_outputs.append(output)

View File

@ -1,22 +1,49 @@
import random
from typing import List, Optional, Tuple
try:
import ray
except ImportError:
ray = None
import socket
from typing import Optional, Tuple, TYPE_CHECKING
from vllm.config import ParallelConfig
# rank, node resource (node IP), device id
DeviceID = Tuple[int, Optional[str], int]
try:
import ray
from ray.air.util.torch_dist import TorchDistributedWorker
class RayWorker(TorchDistributedWorker):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
def __init__(self) -> None:
self.worker = None
def init_worker(self, worker_init_fn):
self.worker = worker_init_fn()
def __getattr__(self, name):
return getattr(self.worker, name)
def execute_method(self, method, *args, **kwargs):
executor = getattr(self, method)
return executor(*args, **kwargs)
except ImportError:
ray = None
TorchDistributedWorker = None
RayWorker = None # pylint: disable=invalid-name
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def initialize_cluster(
parallel_config: ParallelConfig,
engine_use_ray: bool = False,
ray_address: Optional[str] = None,
) -> Tuple[str, List[List[DeviceID]]]:
) -> Tuple[str, Optional["PlacementGroup"]]:
"""Initialize the distributed cluster probably with Ray.
Args:
@ -38,71 +65,46 @@ def initialize_cluster(
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
ray.init(address=ray_address)
ray.init(address=ray_address, ignore_reinit_error=True)
if not parallel_config.worker_use_ray:
# Initialize cluster locally.
port = random.randint(10000, 20000)
port = get_open_port()
# We need to setup the distributed init method to make sure
# the distributed megatron code (e.g., get world size) works correctly.
distributed_init_method = f"tcp://localhost:{port}"
all_stage_devices = [[(0, None, 0)]]
return distributed_init_method, all_stage_devices
return distributed_init_method, None
# Assume we have a uniform cluster that each node has the same number of
# GPUs for now.
valid_node_resources = []
num_devices_per_node = None
for node in ray.nodes():
if (not node["Alive"]) or node["Resources"]["GPU"] <= 0:
continue
if num_devices_per_node is None:
num_devices_per_node = node["Resources"]["GPU"]
else:
assert num_devices_per_node == node["Resources"]["GPU"], (
"The number of GPUs per node is not uniform.")
for key in node["Resources"]:
if key.startswith("node:"):
valid_node_resources.append(key)
# Verify the parallel config.
num_nodes = len(valid_node_resources)
if parallel_config.world_size > num_nodes * num_devices_per_node:
raise ValueError(
"The number of required GPUs exceeds the total number of "
"available GPUs.")
if parallel_config.tensor_parallel_size >= num_devices_per_node:
if parallel_config.tensor_parallel_size % num_devices_per_node != 0:
current_placement_group = ray.util.get_current_placement_group()
if current_placement_group:
# We are in a placement group
bundles = current_placement_group.bundle_specs
# Verify that we can use the placement group.
gpu_bundles = 0
for bundle in bundles:
bundle_gpus = bundle.get("GPU", 0)
if bundle_gpus > 1:
raise ValueError(
"Placement group bundle cannot have more than 1 GPU.")
if bundle_gpus:
gpu_bundles += 1
if parallel_config.world_size > gpu_bundles:
raise ValueError(
"The number of tensor parallelism is not divisible by the "
"number of GPUs per node.")
"The number of required GPUs exceeds the total number of "
"available GPUs in the placement group.")
else:
if num_devices_per_node % parallel_config.tensor_parallel_size != 0:
num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
if parallel_config.world_size > num_gpus_in_cluster:
raise ValueError(
"The number of GPUs per node is not divisible by the number "
"of tensor parallelism.")
"The number of required GPUs exceeds the total number of "
"available GPUs in the cluster.")
# Create a new placement group
current_placement_group = ray.util.placement_group([{
"GPU": 1
}] * parallel_config.world_size)
# Wait until PG is ready - this will block until all
# requested resources are available, and will timeout
# if they cannot be provisioned.
ray.get(current_placement_group.ready(), timeout=1800)
# Assign GPUs to pipeline stages.
rank = 0
current_node_id = 0
current_device_id = 0
distributed_init_method = None
all_stage_devices = []
for _ in range(parallel_config.pipeline_parallel_size):
stage_devices = []
for _ in range(parallel_config.tensor_parallel_size):
node_resource = valid_node_resources[current_node_id]
stage_devices.append((rank, node_resource, current_device_id))
if distributed_init_method is None:
ip = node_resource.split("node:")[-1]
port = random.randint(10000, 20000)
distributed_init_method = f"tcp://{ip}:{port}"
rank += 1
current_device_id += 1
if current_device_id >= num_devices_per_node:
current_node_id += 1
current_device_id = 0
all_stage_devices.append(stage_devices)
return distributed_init_method, all_stage_devices
return None, current_placement_group

View File

@ -3,7 +3,7 @@ import json
from typing import AsyncGenerator
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import Response, StreamingResponse
from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn
from vllm.engine.arg_utils import AsyncEngineArgs
@ -64,7 +64,7 @@ async def generate(request: Request) -> Response:
prompt = final_output.prompt
text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
return Response(content=json.dumps(ret))
return JSONResponse(ret)
if __name__ == "__main__":

View File

@ -28,6 +28,8 @@ class LLM:
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
if available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
@ -43,6 +45,7 @@ class LLM:
model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tensor_parallel_size: int = 1,
dtype: str = "auto",
seed: int = 0,
@ -54,6 +57,7 @@ class LLM:
model=model,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
seed=seed,
@ -151,4 +155,8 @@ class LLM:
pbar.update(1)
if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
outputs = sorted(outputs, key=lambda x: int(x.request_id))
return outputs

View File

@ -7,15 +7,13 @@ from http import HTTPStatus
import json
import time
from typing import AsyncGenerator, Dict, List, Optional
from packaging import version
import fastapi
from fastapi import BackgroundTasks, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastchat.conversation import Conversation, SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
import uvicorn
from vllm.engine.arg_utils import AsyncEngineArgs
@ -33,6 +31,14 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid
try:
import fastchat
from fastchat.conversation import Conversation, SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
_fastchat_available = True
except ImportError:
_fastchat_available = False
TIMEOUT_KEEP_ALIVE = 5 # seconds
logger = init_logger(__name__)
@ -63,10 +69,21 @@ async def check_model(request) -> Optional[JSONResponse]:
async def get_gen_prompt(request) -> str:
if not _fastchat_available:
raise ModuleNotFoundError(
"fastchat is not installed. Please install fastchat to use "
"the chat completion and conversation APIs: `$ pip install fschat`"
)
if version.parse(fastchat.__version__) < version.parse("0.2.23"):
raise ImportError(
f"fastchat version is low. Current version: {fastchat.__version__} "
"Please upgrade fastchat to use: `$ pip install -U fschat`")
conv = get_conversation_template(request.model)
conv = Conversation(
name=conv.name,
system=conv.system,
system_template=conv.system_template,
system_message=conv.system_message,
roles=conv.roles,
messages=list(conv.messages), # prevent in-place modification
offset=conv.offset,
@ -83,7 +100,7 @@ async def get_gen_prompt(request) -> str:
for message in request.messages:
msg_role = message["role"]
if msg_role == "system":
conv.system = message["content"]
conv.system_message = message["content"]
elif msg_role == "user":
conv.append_message(conv.roles[0], message["content"])
elif msg_role == "assistant":
@ -98,25 +115,14 @@ async def get_gen_prompt(request) -> str:
return prompt
async def check_length(request, prompt, model_config):
if hasattr(model_config.hf_config, "max_sequence_length"):
context_len = model_config.hf_config.max_sequence_length
elif hasattr(model_config.hf_config, "seq_length"):
context_len = model_config.hf_config.seq_length
elif hasattr(model_config.hf_config, "max_position_embeddings"):
context_len = model_config.hf_config.max_position_embeddings
elif hasattr(model_config.hf_config, "seq_length"):
context_len = model_config.hf_config.seq_length
else:
context_len = 2048
async def check_length(request, prompt):
input_ids = tokenizer(prompt).input_ids
token_num = len(input_ids)
if token_num + request.max_tokens > context_len:
if token_num + request.max_tokens > max_model_len:
return create_error_response(
HTTPStatus.BAD_REQUEST,
f"This model's maximum context length is {context_len} tokens. "
f"This model's maximum context length is {max_model_len} tokens. "
f"However, you requested {request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
@ -185,7 +191,7 @@ async def create_chat_completion(raw_request: Request):
"logit_bias is not currently supported")
prompt = await get_gen_prompt(request)
error_check_ret = await check_length(request, prompt, engine_model_config)
error_check_ret = await check_length(request, prompt)
if error_check_ret is not None:
return error_check_ret
@ -269,7 +275,7 @@ async def create_chat_completion(raw_request: Request):
finish_reason=output.finish_reason,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if request.stream:
@ -465,7 +471,7 @@ async def create_completion(raw_request: Request):
finish_reason=output.finish_reason,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if stream:
@ -582,10 +588,12 @@ if __name__ == "__main__":
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
engine_model_config = asyncio.run(engine.get_model_config())
max_model_len = engine_model_config.get_max_model_len()
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(engine_args.tokenizer,
tokenizer_mode=engine_args.tokenizer_mode)
tokenizer_mode=engine_args.tokenizer_mode,
trust_remote_code=engine_args.trust_remote_code)
uvicorn.run(app,
host=args.host,

View File

@ -12,7 +12,7 @@ from vllm import cache_ops
from vllm import pos_encoding_ops
from vllm.model_executor.input_metadata import InputMetadata
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128]
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
class PagedAttention(nn.Module):
@ -20,12 +20,20 @@ class PagedAttention(nn.Module):
"""GPT-style multi-head PagedAttention.
This class takes flattened 1D query, key, and value tensors as input. The
input 1D tensors can be split into three parts: the prompt tokens, the
generation tokens, and the paddings.
input 1D tensors can either contain prompt tokens or generation tokens, in
addition to paddings.
|<------------------------------------- num_valid_tokens ------------------------------------->|
|<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
If the input tensors contain prompt tokens, the layout is as follows:
|<---------------------- num_valid_tokens ---------------------->|
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|
Otherwise, the layout is as follows:
|<------------------ num_valid_tokens ------------------->|
|<------- num_generation_tokens (M) ------->|
|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
The prompts might have different lengths, while the generation tokens always
have length 1. The paddings are appended to make the input length a multiple
@ -44,12 +52,23 @@ class PagedAttention(nn.Module):
5. Output a flattened 1D tensor.
"""
def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
def __init__(self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.attn_op = xops.fmha.cutlass.FwOp()
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.head_mapping = torch.repeat_interleave(
torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
self.num_queries_per_kv)
if self.head_size not in _SUPPORTED_HEAD_SIZES:
raise ValueError(f"head_size ({self.head_size}) is not supported. "
@ -76,10 +95,18 @@ class PagedAttention(nn.Module):
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_heads, head_size]
value: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
input_metadata: metadata for paged attention.
"""
if self.num_kv_heads != self.num_heads:
# Project the key and value tensors to the desired number of heads.
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value,
self.num_queries_per_kv,
dim=1)
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
@ -107,9 +134,10 @@ class PagedAttention(nn.Module):
Args:
output: shape = [num_generation_tokens, num_heads, head_size]
query: shape = [num_generation_tokens, num_heads, head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for paged attention.
"""
block_size = value_cache.shape[3]
@ -118,6 +146,7 @@ class PagedAttention(nn.Module):
query,
key_cache,
value_cache,
self.head_mapping,
self.scale,
input_metadata.block_tables,
input_metadata.context_lens,
@ -143,11 +172,12 @@ class PagedAttention(nn.Module):
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_heads * head_size]
value: shape = [num_tokens, num_heads * head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for paged attention.
cache_event: event to wait for the cache operations to finish.
@ -157,8 +187,8 @@ class PagedAttention(nn.Module):
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
# Pre-allocate the output tensor.
output = torch.empty_like(query)
@ -166,6 +196,8 @@ class PagedAttention(nn.Module):
# Compute the attention op for prompts.
num_prompt_tokens = input_metadata.num_prompt_tokens
if num_prompt_tokens > 0:
# Prompt run.
assert input_metadata.num_generation_tokens == 0
self.set_attn_bias(input_metadata)
self.multi_query_kv_attention(
output[:num_prompt_tokens],
@ -195,6 +227,8 @@ class PagedAttention(nn.Module):
)
if input_metadata.num_generation_tokens > 0:
# Decoding run.
assert input_metadata.num_prompt_tokens == 0
assert key_cache is not None and value_cache is not None, (
"key_cache and value_cache must be provided when "
"generating tokens.")
@ -220,8 +254,9 @@ class PagedAttentionWithRoPE(PagedAttention):
rotary_dim: int,
max_position: int = 8192,
base: int = 10000,
num_kv_heads: Optional[int] = None,
) -> None:
super().__init__(num_heads, head_size, scale)
super().__init__(num_heads, head_size, scale, num_kv_heads)
# Create the cos and sin cache.
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
@ -255,11 +290,12 @@ class PagedAttentionWithRoPE(PagedAttention):
Args:
positions: shape = [num_tokens]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_heads * head_size]
value: shape = [num_tokens, num_heads * head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for paged attention.
cache_event: event to wait for the cache operations to finish.
@ -290,14 +326,13 @@ class PagedAttentionWithRoPE(PagedAttention):
class PagedAttentionWithALiBi(PagedAttention):
"""PagedAttention with ALiBi attention bias."""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
slopes: List[float],
) -> None:
super().__init__(num_heads, head_size, scale)
def __init__(self,
num_heads: int,
head_size: int,
scale: float,
slopes: List[float],
num_kv_heads: Optional[int] = None) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads)
assert len(slopes) == num_heads
slopes = torch.tensor(slopes, dtype=torch.float32)
@ -310,6 +345,11 @@ class PagedAttentionWithALiBi(PagedAttention):
# Generates ALiBi mask for each prompt.
for prompt_len in input_metadata.prompt_lens:
bias = torch.arange(prompt_len)
# Note(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
bias = bias.to(self.alibi_slopes.device)
@ -339,10 +379,17 @@ class PagedAttentionWithALiBi(PagedAttention):
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_heads, head_size]
value: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
input_metadata: metadata for paged attention.
"""
if self.num_kv_heads != self.num_heads:
# Project the key and value tensors to the desired number of heads.
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value,
self.num_queries_per_kv,
dim=1)
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
@ -376,9 +423,10 @@ class PagedAttentionWithALiBi(PagedAttention):
Args:
output: shape = [num_generation_tokens, num_heads, head_size]
query: shape = [num_generation_tokens, num_heads, head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for paged attention.
"""
block_size = value_cache.shape[3]
@ -387,6 +435,7 @@ class PagedAttentionWithALiBi(PagedAttention):
query,
key_cache,
value_cache,
self.head_mapping,
self.scale,
input_metadata.block_tables,
input_metadata.context_lens,

View File

@ -38,12 +38,15 @@ class Sampler(nn.Module):
embedding: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> Dict[int, SequenceOutputs]:
# Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = gather_from_tensor_model_parallel_region(logits)
# Remove paddings in vocab (if any).
logits = logits[:, :self.vocab_size]

View File

@ -11,14 +11,19 @@ from vllm.model_executor.weight_utils import initialize_dummy_weights
# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
"BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b
"BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b
"BloomForCausalLM": BloomForCausalLM,
"FalconForCausalLM": FalconForCausalLM,
"GPT2LMHeadModel": GPT2LMHeadModel,
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
"GPTJForCausalLM": GPTJForCausalLM,
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
"MPTForCausalLM": MPTForCausalLM,
"OPTForCausalLM": OPTForCausalLM,
"RWForCausalLM": FalconForCausalLM,
}

View File

@ -1,15 +1,23 @@
from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM,
BaichuanForCausalLM)
from vllm.model_executor.models.bloom import BloomForCausalLM
from vllm.model_executor.models.falcon import FalconForCausalLM
from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
from vllm.model_executor.models.gpt_j import GPTJForCausalLM
from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.mpt import MPTForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM
__all__ = [
"BaiChuanForCausalLM",
"BaichuanForCausalLM",
"BloomForCausalLM",
"FalconForCausalLM",
"GPT2LMHeadModel",
"GPTBigCodeForCausalLM",
"GPTJForCausalLM",
"GPTNeoXForCausalLM",
"LlamaForCausalLM",
"MPTForCausalLM",

View File

@ -0,0 +1,377 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
import math
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from vllm.sequence import SequenceOutputs
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(start=1,
end=1 + 2 * num_remaining_heads,
step=2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
class BaiChuanMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class BaiChuanAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
hidden_size: int,
num_heads: int,
position_embedding: str,
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
)
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.postion_embedding = position_embedding
# pylint: disable=invalid-name
self.W_pack = ColumnParallelLinear(
hidden_size,
3 * hidden_size,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
# Create the alibi slopes and slice them.
if self.postion_embedding == "ALIBI":
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
scaling, alibi_slopes)
else:
self.scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
if self.postion_embedding == "ALIBI":
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
else:
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
class BaiChuanDecoderLayer(nn.Module):
def __init__(self, config: BaiChuanConfig, position_embedding: str):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = BaiChuanAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
position_embedding=position_embedding,
)
self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class BaiChuanModel(nn.Module):
def __init__(self, config: BaiChuanConfig, position_embedding: str):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([
BaiChuanDecoderLayer(config, position_embedding)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class BaiChuanBaseForCausalLM(nn.Module):
def __init__(self, config, position_embedding: str):
super().__init__()
self.config = config
self.model = BaiChuanModel(config, position_embedding)
self.lm_head = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = [
"embed_tokens.weight",
"lm_head.weight",
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "rotary_emb.inv_freq" in name:
continue
if "embed_tokens" in name or "lm_head" in name:
# Consider padding in the vocab size.
param = state_dict[name]
padded_vocab_size = param.shape[0] * tp_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
if "W_pack" in name:
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tp_world_size
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
is_gate_up_weight = False
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
(tp_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_gate_up_weight = True
break
if is_gate_up_weight:
continue
param = state_dict[name]
load_tensor_parallel_weights(
param,
loaded_weight,
name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank,
)
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b
def __init__(self, config):
super().__init__(config, "ALIBI")
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b
def __init__(self, config):
super().__init__(config, "ROPE")

View File

@ -284,10 +284,17 @@ class BloomForCausalLM(nn.Module):
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if not name.startswith("transformer."):
name = "transformer." + name
if name == "lm_head.weight":
# Since hidden_states are parallelized, we need to
# load lm_head.weight in parallel.
self._column_parallel_weights.append(name)
# If lm_head is provided, use it instead.
param = self.lm_head_weight
else:
if not name.startswith("transformer."):
name = "transformer." + name
param = state_dict[name]
param = state_dict[name]
if "query_key_value" in name:
# NOTE(woosuk): BLOOM's fused QKV has the shape of
# [num_heads * 3 * head_size, hidden_size], while the

View File

@ -0,0 +1,496 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Falcon model."""
import math
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import (PagedAttention,
PagedAttentionWithALiBi,
PagedAttentionWithRoPE)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
reduce_from_tensor_model_parallel_region)
from vllm.sequence import SequenceOutputs
from vllm.transformers_utils.configs import RWConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
FalconConfig = Union[HF_FalconConfig, RWConfig]
# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during
# training, this means that there's one additional quantization to bfloat16
# between the operations. In order not to degrade the quality of our HF-port,
# we keep these characteristics in the final model.
class FalconLinear(nn.Linear):
def forward(self, x: torch.Tensor) -> torch.Tensor:
hidden_states = x @ self.weight.T
if self.bias is None:
return hidden_states
return hidden_states + self.bias
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(1,
1 + 2 * num_remaining_heads,
2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
class FalconAttention(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.head_dim = self.hidden_size // self.total_num_heads
assert self.head_dim * self.total_num_heads == self.hidden_size
self.new_decoder_architecture = config.new_decoder_architecture
self.multi_query = config.multi_query
if self.new_decoder_architecture:
self.total_num_kv_heads = config.num_kv_heads
assert self.total_num_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.query_key_value = ColumnParallelLinear(
self.hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True,
)
elif self.multi_query:
self.total_num_kv_heads = 1
self.num_kv_heads = 1
self.query = ColumnParallelLinear(
self.hidden_size,
self.total_num_heads * self.head_dim,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True,
)
self.key_value = FalconLinear(self.hidden_size,
2 * self.head_dim,
bias=config.bias)
else:
self.total_num_kv_heads = self.total_num_heads
self.num_kv_heads = self.num_heads
self.query_key_value = ColumnParallelLinear(
self.hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True,
)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=config.bias,
input_is_parallel=True,
perform_initialization=False,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results)
self.use_rotary = config.rotary
self.use_alibi = config.alibi
assert not (self.use_rotary and self.use_alibi), (
"Rotary and alibi are mutually exclusive.")
if self.use_rotary:
# TODO(zhuohan): Pass in correct `max_position``
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.inv_norm_factor,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads)
elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
self.inv_norm_factor)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
self.attn = PagedAttentionWithALiBi(self.num_heads,
self.head_dim,
self.inv_norm_factor,
alibi_slopes,
num_kv_heads=self.num_kv_heads)
else:
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
if not self.new_decoder_architecture and self.multi_query:
q, bias = self.query(hidden_states)
if bias is not None:
q += bias
kv = self.key_value(hidden_states)
k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
else:
qkv, bias = self.query_key_value(hidden_states)
if bias is not None:
qkv += bias
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
k_cache, v_cache = kv_cache
if self.use_rotary:
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
else:
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output, bias = self.dense(attn_output)
return attn_output, bias
class FalconMLP(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
4 * hidden_size,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True)
self.act = nn.GELU()
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size,
hidden_size,
bias=config.bias,
input_is_parallel=True,
perform_initialization=False,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
x, bias = self.dense_h_to_4h(x)
if bias is not None:
x += bias
x = self.act(x)
x, bias = self.dense_4h_to_h(x)
return x, bias
class FalconDecoderLayer(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config)
self.mlp = FalconMLP(config)
self.config = config
if config.new_decoder_architecture:
# The layer norm before self-attention
self.ln_attn = LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
# The layer norm before the MLP
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
else:
self.input_layernorm = LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
if not config.parallel_attn:
self.post_attention_layernorm = LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
):
residual = hidden_states
if self.config.new_decoder_architecture:
attention_layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
attention_layernorm_out = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = self.self_attention(
positions=positions,
hidden_states=attention_layernorm_out,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
if self.reduce_row_parallel_results and attention_bias is not None:
attention_output += attention_bias
if not self.config.new_decoder_architecture:
if self.config.parallel_attn:
mlp_layernorm_out = attention_layernorm_out
else:
residual += attention_output
mlp_layernorm_out = self.post_attention_layernorm(residual)
# MLP.
mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
if self.reduce_row_parallel_results and mlp_bias is not None:
mlp_output += mlp_bias
if not self.reduce_row_parallel_results:
# When MLP and Attention layers are parallel, we can use
# only one all-reduce operator to reduce the results from
# both MLP and Attention layers.
mlp_output += attention_output
mlp_output = reduce_from_tensor_model_parallel_region(mlp_output)
if attention_bias is not None:
mlp_output += attention_bias
if mlp_bias is not None:
mlp_output += mlp_bias
output = mlp_output + residual
return output
class FalconModel(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.use_alibi = config.alibi
# Embedding + LN Embedding
self.word_embeddings = VocabParallelEmbedding(
config.vocab_size, self.embed_dim, perform_initialization=False)
# Transformer blocks
self.h = nn.ModuleList([
FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
for i in range(len(self.h)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class FalconForCausalLM(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
self.config = config
self.transformer = FalconModel(config)
self.lm_head = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer(
input_ids,
positions,
kv_caches,
input_metadata,
cache_events,
)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = [
"word_embeddings.weight", "lm_head.weight", "dense_h_to_4h.weight",
"dense_h_to_4h.bias"
]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tp_size = (get_tensor_model_parallel_world_size())
tp_rank = get_tensor_model_parallel_rank()
hidden_size = self.config.hidden_size
total_num_heads = self.config.num_attention_heads
num_heads = total_num_heads // tp_size
head_size = hidden_size // total_num_heads
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads
if self.config.new_decoder_architecture:
total_num_kv_heads = self.config.num_kv_heads
num_kv_heads = total_num_kv_heads // tp_size
separated_q_kv = False
kv_head_start = tp_rank * num_kv_heads
kv_head_end = (tp_rank + 1) * num_kv_heads
elif self.config.multi_query:
total_num_kv_heads = 1
num_kv_heads = 1
separated_q_kv = True
kv_head_start = 0
kv_head_end = 1
else:
total_num_kv_heads = total_num_heads
num_kv_heads = total_num_kv_heads // tp_size
separated_q_kv = False
kv_head_start = tp_rank * num_kv_heads
kv_head_end = (tp_rank + 1) * num_kv_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "query_key_value" in name:
loaded_weight_size = loaded_weight.size()
loaded_weight = loaded_weight.view(
total_num_kv_heads, num_query_heads_per_kv_head + 2,
head_size, *loaded_weight_size[1:])
wq = loaded_weight[:, :-2].reshape(-1, *loaded_weight_size[1:])
wk = loaded_weight[:, [-2]].reshape(-1,
*loaded_weight_size[1:])
wv = loaded_weight[:, [-1]].reshape(-1,
*loaded_weight_size[1:])
wq = wq[head_size * head_start:head_size * head_end]
wk = wk[head_size * kv_head_start:head_size * kv_head_end]
wv = wv[head_size * kv_head_start:head_size * kv_head_end]
if separated_q_kv:
loaded_weight_q = wq
loaded_weight_kv = torch.cat([wk, wv], dim=0)
q_weight_name = name.replace("query_key_value", "query")
kv_weight_name = name.replace("query_key_value",
"key_value")
load_tensor_parallel_weights(state_dict[q_weight_name],
loaded_weight_q,
q_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank)
load_tensor_parallel_weights(state_dict[kv_weight_name],
loaded_weight_kv,
kv_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tp_rank)
continue
else:
loaded_weight = torch.cat([wq, wk, wv], dim=0)
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)

View File

@ -26,7 +26,6 @@ from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
import numpy as np
from transformers import GPTBigCodeConfig
from vllm.model_executor.input_metadata import InputMetadata
@ -57,11 +56,28 @@ class GPTBigCodeAttention(nn.Module):
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim**-0.5
self.c_attn = ColumnParallelLinear(self.hidden_size,
3 * self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.multi_query = config.multi_query
if self.multi_query:
self.num_kv_heads = 1
self.kv_dim = self.head_dim
self.c_attn_q = ColumnParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_attn_kv = nn.Linear(self.hidden_size,
2 * self.kv_dim,
bias=True)
else:
self.num_kv_heads = self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.c_attn = ColumnParallelLinear(self.hidden_size,
self.hidden_size +
2 * self.kv_dim,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
@ -69,7 +85,8 @@ class GPTBigCodeAttention(nn.Module):
perform_initialization=False)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scale)
scale=self.scale,
num_kv_heads=self.num_kv_heads)
def forward(
self,
@ -78,8 +95,14 @@ class GPTBigCodeAttention(nn.Module):
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.multi_query:
q, _ = self.c_attn_q(hidden_states)
kv = self.c_attn_kv(hidden_states)
k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
else:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
@ -248,11 +271,58 @@ class GPTBigCodeForCausalLM(nn.Module):
# NOTE: "c_attn.bias" should not be skipped.
continue
param = state_dict[name]
if not name.startswith("transformer."):
name = "transformer." + name
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
total_num_kv_heads = (1 if self.config.multi_query else
total_num_heads)
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
total_kv_size = head_size * total_num_kv_heads
num_heads = total_num_heads // tensor_model_parallel_world_size
head_start = tensor_model_parallel_rank * num_heads
head_end = (tensor_model_parallel_rank + 1) * num_heads
wq, wk, wv = torch.split(
loaded_weight, [hidden_size, total_kv_size, total_kv_size],
dim=0)
wq = wq[head_size * head_start:head_size * head_end]
if not self.config.multi_query:
# Split the heads when using normal multi-head attention
wk = wk[head_size * head_start:head_size * head_end]
wv = wv[head_size * head_start:head_size * head_end]
loaded_weight = torch.cat([wq, wk, wv], dim=0)
else:
# For multi-query attention, we split the query
# but replicate the key and value.
loaded_weight_q = wq
loaded_weight_kv = torch.cat([wk, wv], dim=0)
q_weight_name = name.replace("c_attn", "c_attn_q")
kv_weight_name = name.replace("c_attn", "c_attn_kv")
load_tensor_parallel_weights(state_dict[q_weight_name],
loaded_weight_q,
q_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
load_tensor_parallel_weights(state_dict[kv_weight_name],
loaded_weight_kv,
kv_weight_name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
continue
param = state_dict[name]
if name == "transformer.wte.weight":
# Consider padding in the vocab size.
padded_vocab_size = param.shape[
@ -263,68 +333,6 @@ class GPTBigCodeForCausalLM(nn.Module):
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
def _expand_mqa_mha(qkv_array, n_head, head_dim):
"""manipulates along axis=0 from MQA to MHA
inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim)
with n_heads for q, then 1 for k, 1 for 1 v, times head dim
return: qkv_array.shape=(3 * n_heads * head_dim, hidden_dim)
TODO: this function is no longer needed once vllm supports MQA.
"""
qkv_array = qkv_array.numpy()
dims_q = n_head * head_dim
# pylint: disable=unbalanced-tuple-unpacking
q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
axis=0)
# q is fine, but k & v have not replicated shape along the first
# axis as long as MQA is not nativly supported, increase memory
# and replicated (head_dim, hidden_dim) to
# (n_heads * head_dim, hidden_dim)
if k.ndim == 2 and v.ndim == 2:
replication = (n_head, 1) # weights
else:
replication = n_head # biases
# replicate n_head times for q, v
k, v = np.tile(k, replication), np.tile(v, replication)
# concat q, k, v along the first axis
# (n_heads * head_dim, hidden_dim)
# to (3 * n_heads * head_dim, hidden_dim)
qkv_array = np.concatenate((q, k, v), axis=0)
return torch.from_numpy(qkv_array)
# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tensor_model_parallel_world_size
head_start = tensor_model_parallel_rank * num_heads
head_end = (tensor_model_parallel_rank + 1) * num_heads
if name.endswith(".weight"):
loaded_weight = _expand_mqa_mha(loaded_weight,
n_head=total_num_heads,
head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
loaded_weight = _expand_mqa_mha(loaded_weight,
n_head=total_num_heads,
head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:
raise ValueError(f"Unexpected parameter name {name}")
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,

View File

@ -0,0 +1,251 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import GPTJConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
KVCache = Tuple[torch.Tensor, torch.Tensor]
class GPTJAttention(nn.Module):
def __init__(self, config: GPTJConfig):
super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
self.qkv_proj = ColumnParallelLinear(config.hidden_size,
3 * config.hidden_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.out_proj = RowParallelLinear(config.hidden_size,
config.hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
scaling = self.head_size**-0.5
assert getattr(config, "rotary", True)
assert config.rotary_dim % 2 == 0
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
scaling, config.rotary_dim)
self.warmup = False
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
attn_output, _ = self.out_proj(attn_output)
return attn_output
class GPTJMLP(nn.Module):
def __init__(self, intermediate_size: int, config: GPTJConfig):
super().__init__()
hidden_size = config.n_embd
self.fc_in = ColumnParallelLinear(hidden_size,
intermediate_size,
gather_output=False,
perform_initialization=False)
self.fc_out = RowParallelLinear(intermediate_size,
hidden_size,
input_is_parallel=True,
perform_initialization=False)
self.act = get_act_fn(config.activation_function)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc_in(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.fc_out(hidden_states)
return hidden_states
class GPTJBlock(nn.Module):
def __init__(self, config: GPTJConfig):
super().__init__()
if config.n_inner is None:
inner_dim = 4 * config.n_embd
else:
inner_dim = config.n_inner
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJAttention(config)
self.mlp = GPTJMLP(inner_dim, config)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
mlp_output = self.mlp(hidden_states)
hidden_states = attn_output + mlp_output + residual
return hidden_states
class GPTJModel(nn.Module):
def __init__(self, config: GPTJConfig):
super().__init__()
self.config = config
self.embed_dim = config.n_embd
self.wte = VocabParallelEmbedding(config.vocab_size,
self.embed_dim,
perform_initialization=False)
self.h = nn.ModuleList(
[GPTJBlock(config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.h)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPTJForCausalLM(nn.Module):
def __init__(self, config: GPTJConfig):
super().__init__()
self.config = config
assert not config.tie_word_embeddings
self.transformer = GPTJModel(config)
self.lm_head = ColumnParallelLinear(config.n_embd,
config.vocab_size,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata, self.lm_head.bias)
return next_tokens
_column_parallel_weights = [
"wte.weight", "fc_in.weight", "fc_in.bias", "lm_head.weight",
"lm_head.bias"
]
_row_parallel_weights = ["out_proj.weight", "fc_out.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "attn.bias" in name or "attn.masked_bias" in name:
continue
is_attention_weight = False
for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[1]
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
(tp_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break
if is_attention_weight:
continue
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)

View File

@ -84,21 +84,26 @@ class LlamaAttention(nn.Module):
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(
hidden_size,
3 * self.total_num_heads * self.head_dim,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
@ -113,7 +118,8 @@ class LlamaAttention(nn.Module):
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim)
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads)
def forward(
self,
@ -124,7 +130,7 @@ class LlamaAttention(nn.Module):
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
@ -140,6 +146,7 @@ class LlamaDecoderLayer(nn.Module):
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
@ -187,10 +194,9 @@ class LlamaModel(nn.Module):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
perform_initialization=False)
vocab_size, config.hidden_size, perform_initialization=False)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
@ -228,8 +234,9 @@ class LlamaForCausalLM(nn.Module):
super().__init__()
self.config = config
self.model = LlamaModel(config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
@ -259,7 +266,19 @@ class LlamaForCausalLM(nn.Module):
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
kv_proj_shard_size = (self.config.hidden_size //
self.config.num_attention_heads *
self.config.num_key_value_heads // tp_size)
attention_weight_specs = [
# (weight_name, shard_size, offset)
("q_proj", q_proj_shard_size, 0),
("k_proj", kv_proj_shard_size, q_proj_shard_size),
("v_proj", kv_proj_shard_size,
q_proj_shard_size + kv_proj_shard_size),
]
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
@ -267,19 +286,28 @@ class LlamaForCausalLM(nn.Module):
if "rotary_emb.inv_freq" in name:
continue
if "embed_tokens" in name or "lm_head" in name:
param = state_dict[name]
# Consider padding in the vocab size.
padded_vocab_size = (param.shape[0] * tp_size)
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
is_attention_weight = False
for stride_id, att_weight_name in enumerate(
["q_proj", "k_proj", "v_proj"]):
if att_weight_name not in name:
for weight_name, shard_size, offset in attention_weight_specs:
if weight_name not in name:
continue
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
shard_size = param.shape[0] // 3
param = state_dict[name.replace(weight_name, "qkv_proj")]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
(tensor_model_parallel_rank + 1)]
param_slice = param.data[shard_size * stride_id:shard_size *
(stride_id + 1)]
param_slice = param.data[offset:offset + shard_size]
assert param_slice.shape == loaded_weight.shape
param_slice.copy_(loaded_weight)
is_attention_weight = True
break

View File

@ -1,3 +1,4 @@
# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import Dict, List, Optional, Tuple

View File

@ -1,9 +1,6 @@
import vllm.model_executor.parallel_utils.parallel_state
import vllm.model_executor.parallel_utils.tensor_parallel
# Alias parallel_state as mpu, its legacy name
mpu = parallel_state
__all__ = [
"parallel_state",
"tensor_parallel",

View File

@ -44,7 +44,6 @@ _PIPELINE_GLOBAL_RANKS = None
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None
_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
@ -196,20 +195,6 @@ def initialize_model_parallel(
if rank in ranks:
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
def initialize_all_reduce_launcher(
max_num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
disable_graph: bool = False,
) -> None:
global _ALL_REDUCE_LAUNCHER
_ALL_REDUCE_LAUNCHER = GraphAllReduce(
max_num_tokens=max_num_tokens,
hidden_size=hidden_size,
dtype=dtype,
disable_graph=disable_graph,
)
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
@ -458,6 +443,7 @@ def get_pipeline_model_parallel_last_rank():
last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_next_rank():
"""Return the global rank that follows the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, \
@ -485,10 +471,6 @@ def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return torch.distributed.get_rank(group=get_data_parallel_group())
def get_all_reduce_launcher() -> 'GraphAllReduce':
assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
return _ALL_REDUCE_LAUNCHER
def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
@ -515,56 +497,3 @@ def destroy_model_parallel():
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
class GraphAllReduce:
def __init__(
self,
max_num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
disable_graph: bool = False,
) -> None:
self.max_num_tokens = max_num_tokens
self.hidden_size = hidden_size
self.disable_graph = disable_graph
tp_world_size = get_tensor_model_parallel_world_size()
if tp_world_size == 1:
return
self.group = get_tensor_model_parallel_group()
self.buffer = torch.empty(
size=(max_num_tokens, hidden_size),
dtype=dtype,
device='cuda',
)
# Build graphs for different number of tokens.
if not self.disable_graph:
self.graphs = {}
for num_tokens in range(8, max_num_tokens + 1, 8):
self.graphs[num_tokens] = self._build_graph(num_tokens)
def _build_graph(self, num_tokens: int) -> torch.cuda.CUDAGraph:
# Warm up.
torch.distributed.all_reduce(self.buffer[:num_tokens], group=self.group)
torch.cuda.synchronize()
# Build graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
torch.distributed.all_reduce(
self.buffer[:num_tokens], group=self.group)
torch.cuda.synchronize()
return graph
def launch(self, x: torch.Tensor) -> torch.Tensor:
# NOTE: x must be a slice of self.buffer.
num_tokens = x.shape[0]
if self.disable_graph:
torch.distributed.all_reduce(x, group=self.group)
else:
self.graphs[num_tokens].replay()
return x

View File

@ -12,6 +12,7 @@ from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region,
scatter_to_sequence_parallel_region,
)
@ -38,7 +39,7 @@ __all__ = [
"copy_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region",
"gather_from_sequence_parallel_region",
# "reduce_from_tensor_model_parallel_region",
"reduce_from_tensor_model_parallel_region",
"scatter_to_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
# random.py

View File

@ -14,7 +14,6 @@ from torch.nn.parameter import Parameter
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_all_reduce_launcher,
)
from .mappings import (
copy_to_tensor_model_parallel_region,
@ -248,8 +247,8 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
self.world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, self.world_size)
self.skip_bias_add = skip_bias_add
if params_dtype is None:
@ -350,6 +349,7 @@ class RowParallelLinear(torch.nn.Module):
params_dtype:
use_cpu_initialization:
perform_initialization:
reduce_results:
"""
def __init__(self, input_size, output_size, *,
@ -360,6 +360,7 @@ class RowParallelLinear(torch.nn.Module):
params_dtype=None,
use_cpu_initialization=False,
perform_initialization=True,
reduce_results=True,
):
super(RowParallelLinear, self).__init__()
@ -367,14 +368,19 @@ class RowParallelLinear(torch.nn.Module):
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Divide the weight matrix along the last dimension.
world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
self.world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.world_size)
self.skip_bias_add = skip_bias_add
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
@ -427,17 +433,12 @@ class RowParallelLinear(torch.nn.Module):
input_parallel = input_
else:
input_parallel = scatter_to_tensor_model_parallel_region(input_)
if get_tensor_model_parallel_world_size() == 1:
# Matrix multiply.
output_ = F.linear(input_parallel, self.weight)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
if self.reduce_results and self.world_size > 1:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
else:
# Matrix multiply.
all_reduce_launcher = get_all_reduce_launcher()
num_tokens = input_parallel.shape[0]
output_buffer = all_reduce_launcher.buffer[:num_tokens]
torch.matmul(input_parallel, self.weight_t, out=output_buffer)
# All-reduce across all the partitions.
output_ = all_reduce_launcher.launch(output_buffer)
output_ = output_parallel
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_

View File

@ -39,7 +39,10 @@ def hf_model_weights_iterator(
else:
hf_folder = model_name_or_path
hf_bin_files = glob.glob(os.path.join(hf_folder, "*.bin"))
hf_bin_files = [
x for x in glob.glob(os.path.join(hf_folder, "*.bin"))
if not x.endswith("training_args.bin")
]
if use_np_cache:
# Convert the model weights from torch tensors to numpy arrays for

View File

@ -127,7 +127,8 @@ class Sequence:
self.logical_token_blocks.append(block)
def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
while token_ids:
cursor = 0
while cursor < len(token_ids):
if not self.logical_token_blocks:
self._append_logical_block()
@ -137,8 +138,9 @@ class Sequence:
last_block = self.logical_token_blocks[-1]
num_empty_slots = last_block.get_num_empty_slots()
last_block.append_tokens(token_ids[:num_empty_slots])
token_ids = token_ids[num_empty_slots:]
last_block.append_tokens(token_ids[cursor:cursor +
num_empty_slots])
cursor += num_empty_slots
def append_token_id(
self,

View File

@ -4,11 +4,27 @@ from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import
_CONFIG_REGISTRY = {
"mpt": MPTConfig,
"baichuan": BaiChuanConfig,
"RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
}
def get_config(model: str) -> PretrainedConfig:
config = AutoConfig.from_pretrained(model, trust_remote_code=True)
def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
try:
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code)
except ValueError as e:
if (not trust_remote_code and
"requires you to execute the configuration file" in str(e)):
err_msg = (
"Failed to load the model config. If the model is a custom "
"model not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model)

View File

@ -1,5 +1,12 @@
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
__all__ = [
"MPTConfig",
"BaiChuanConfig",
"RWConfig",
]

View File

@ -0,0 +1,62 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.configuration_utils import PretrainedConfig
class BaiChuanConfig(PretrainedConfig):
model_type = "baichuan"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=64000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

View File

@ -0,0 +1,87 @@
# Adapted from
# https://huggingface.co/tiiuae/falcon-7b/blob/main/configuration_RW.py
# Copyright 2023 The vLLM team.
# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Falcon configuration"""
from transformers.configuration_utils import PretrainedConfig
class RWConfig(PretrainedConfig):
model_type = "falcon"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
"num_hidden_layers": "n_layer",
"num_attention_heads": "n_head",
"num_kv_heads": "n_head_kv",
}
def __init__(
self,
vocab_size=250880,
hidden_size=64,
n_layer=2,
n_head=8,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
use_cache=True,
bos_token_id=1,
eos_token_id=2,
hidden_dropout=0.0,
attention_dropout=0.0,
multi_query=True,
n_head_kv=None,
alibi=False,
bias=False,
parallel_attn=False,
new_decoder_architecture=False,
**kwargs,
) -> None:
self.vocab_size = vocab_size
# Backward compatibility with n_embed kwarg
n_embed = kwargs.pop("n_embed", None)
self.hidden_size = hidden_size if n_embed is None else n_embed
self.n_layer = n_layer
self.n_head = n_head
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.use_cache = use_cache
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.multi_query = multi_query
self.n_head_kv = 1 if n_head_kv is None else n_head_kv
self.alibi = alibi
self.bias = bias
self.parallel_attn = parallel_attn
self.new_decoder_architecture = new_decoder_architecture
if self.hidden_size == 8192:
# Hack for falcon-40b
self.new_decoder_architecture = True
super().__init__(bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs)
@property
def head_dim(self):
return self.hidden_size // self.n_head
@property
def rotary(self):
return not self.alibi

View File

@ -15,6 +15,7 @@ def get_tokenizer(
tokenizer_name: str,
*args,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface."""
@ -31,8 +32,11 @@ def get_tokenizer(
f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
"tokenizer.")
try:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, *args,
**kwargs)
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
**kwargs)
except TypeError as e:
# The LLaMA tokenizer causes a protobuf error in some environments.
err_msg = (
@ -40,6 +44,20 @@ def get_tokenizer(
f"model, use '{_FAST_LLAMA_TOKENIZER}' instead of the original "
"tokenizer.")
raise RuntimeError(err_msg) from e
except ValueError as e:
# If the error pertains to the tokenizer class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
if (not trust_remote_code and
("does not exist or is not currently imported." in str(e)
or "requires you to execute the tokenizer file" in str(e))):
err_msg = (
"Failed to load the tokenizer. If the tokenizer is a custom "
"tokenizer not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
if not isinstance(tokenizer, PreTrainedTokenizerFast):
logger.warning(
@ -62,6 +80,8 @@ def detokenize_incrementally(
new_token: The new token as a string.
output_text: The new output text as a string.
"""
if skip_special_tokens and (new_token_id in tokenizer.all_special_ids):
return None, prev_output_tokens
new_token = tokenizer.convert_ids_to_tokens(
new_token_id, skip_special_tokens=skip_special_tokens)
output_tokens = prev_output_tokens + [new_token]
@ -81,7 +101,7 @@ def detokenize_incrementally(
sub_texts = []
current_sub_text = []
for token in output_tokens:
if skip_special_tokens and token in tokenizer.all_special_ids:
if skip_special_tokens and token in tokenizer.all_special_tokens:
continue
if token in tokenizer.added_tokens_encoder:
if current_sub_text:

View File

@ -1,13 +1,15 @@
"""A GPU worker class."""
from typing import Dict, List, Tuple
import os
from typing import Dict, List, Tuple, Optional
import torch
import torch.distributed
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.model_executor import get_model, InputMetadata, set_random_seed
from vllm.model_executor.parallel_utils.parallel_state import (
initialize_model_parallel, initialize_all_reduce_launcher)
initialize_model_parallel)
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata, SequenceOutputs
from vllm.worker.cache_engine import CacheEngine
@ -27,8 +29,8 @@ class Worker:
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
rank: int,
distributed_init_method: str,
rank: Optional[int] = None,
distributed_init_method: Optional[str] = None,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
@ -36,19 +38,6 @@ class Worker:
self.rank = rank
self.distributed_init_method = distributed_init_method
# Initialize the distributed environment.
_init_distributed_environment(parallel_config, rank,
distributed_init_method)
# Initialize the model.
set_random_seed(self.model_config.seed)
self.model = get_model(model_config)
initialize_all_reduce_launcher(
self.scheduler_config.max_num_batched_tokens,
self.model_config.get_hidden_size(),
self.model_config.dtype,
)
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
self.cache_config = None
@ -57,6 +46,26 @@ class Worker:
self.cache_events = None
self.gpu_cache = None
def init_model(self):
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
# Env vars will be set by Ray.
self.rank = self.rank if self.rank is not None else int(
os.getenv("RANK", "-1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
self.device = torch.device(f"cuda:{local_rank}")
if self.rank < 0:
raise ValueError("Invalid or unspecified rank.")
torch.cuda.set_device(self.device)
# Initialize the distributed environment.
_init_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method)
# Initialize the model.
set_random_seed(self.model_config.seed)
self.model = get_model(self.model_config)
@torch.inference_mode()
def profile_num_available_blocks(
self,
@ -294,15 +303,28 @@ class Worker:
def _init_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: str,
distributed_init_method: Optional[str] = None,
) -> None:
"""Initialize the distributed environment."""
torch.distributed.init_process_group(
backend="nccl",
world_size=parallel_config.world_size,
rank=rank,
init_method=distributed_init_method,
)
if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
if torch_world_size != parallel_config.world_size:
raise RuntimeError(
"torch.distributed is already initialized but the torch world "
"size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
torch.distributed.init_process_group(
backend="nccl",
world_size=parallel_config.world_size,
rank=rank,
init_method=distributed_init_method,
)
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
initialize_model_parallel(parallel_config.tensor_parallel_size,