mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Compare commits
27 Commits
Author | SHA1 | Date | |
---|---|---|---|
852ef5b4f5 | |||
db09d4ad83 | |||
c957c741d9 | |||
c07ece5ca4 | |||
7a9c20c715 | |||
005ba458b5 | |||
320a622ec4 | |||
c9927c1a6a | |||
fbd80ad409 | |||
22379d5513 | |||
1696725879 | |||
002800f081 | |||
e15932bb60 | |||
ce741ba3e4 | |||
bf87484efa | |||
8ce9c50d40 | |||
32b6816e55 | |||
c128d69856 | |||
55b28b1eee | |||
e11222333f | |||
28873a2799 | |||
0080d8329d | |||
0d93f15694 | |||
becd7a56f1 | |||
75471386de | |||
d2b2eed67c | |||
4b6f069b6f |
@ -17,6 +17,7 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
|
||||
- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
|
||||
- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
|
||||
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
|
||||
|
@ -1,6 +1,8 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename T>
|
||||
@ -34,9 +36,7 @@ void silu_and_mul(
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(d, 1024));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"silu_and_mul_kernel",
|
||||
[&] {
|
||||
@ -71,9 +71,7 @@ __global__ void activation_kernel(
|
||||
dim3 grid(num_tokens); \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2( \
|
||||
at::ScalarType::Half, \
|
||||
at::ScalarType::BFloat16, \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), \
|
||||
"activation_kernel", \
|
||||
[&] { \
|
||||
|
@ -178,7 +178,7 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
// This includes a reduction across the threads in the same thread group.
|
||||
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
|
||||
// Add the ALiBi bias if slopes are given.
|
||||
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;
|
||||
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
|
||||
|
||||
if (thread_group_offset == 0) {
|
||||
// Store the partial reductions to shared memory.
|
||||
@ -246,6 +246,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
accs[i] = 0.f;
|
||||
}
|
||||
|
||||
scalar_t zero_value;
|
||||
zero(zero_value);
|
||||
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
|
||||
const int physical_block_number = block_table[block_idx];
|
||||
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
|
||||
@ -261,6 +263,16 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
if (row_idx < HEAD_SIZE) {
|
||||
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
||||
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
||||
if (block_idx == num_blocks - 1) {
|
||||
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
|
||||
// we should explicitly zero out the values since they may contain NaNs.
|
||||
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
|
||||
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
|
||||
#pragma unroll
|
||||
for (int j = 0; j <= V_VEC_SIZE; j++) {
|
||||
v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
|
||||
}
|
||||
}
|
||||
accs[i] += dot(logits_vec, v_vec);
|
||||
}
|
||||
}
|
||||
|
@ -420,4 +420,14 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
|
||||
#endif
|
||||
}
|
||||
|
||||
// Zero-out a variable.
|
||||
inline __device__ void zero(__nv_bfloat16& dst) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
// Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
|
||||
dst = __ushort_as_bfloat16((unsigned short)0x0000U);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
@ -390,11 +390,6 @@ inline __device__ float sum(uint4 v) {
|
||||
return sum(c);
|
||||
}
|
||||
|
||||
// Zero-out a vector.
|
||||
inline __device__ void zero(uint16_t& dst) {
|
||||
dst = uint16_t(0);
|
||||
}
|
||||
|
||||
// From float32 to float16.
|
||||
inline __device__ void from_float(uint16_t& dst, float src) {
|
||||
dst = float_to_half(src);
|
||||
@ -441,4 +436,9 @@ inline __device__ Float8_ to_float(uint4 u) {
|
||||
return tmp;
|
||||
}
|
||||
|
||||
// Zero-out a variable.
|
||||
inline __device__ void zero(uint16_t& dst) {
|
||||
dst = uint16_t(0);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
@ -265,4 +265,9 @@ inline __device__ Float8_ to_float(Float8_ u) {
|
||||
return u;
|
||||
}
|
||||
|
||||
// Zero-out a variable.
|
||||
inline __device__ void zero(float& dst) {
|
||||
dst = 0.f;
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
@ -1,6 +1,8 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <map>
|
||||
@ -125,9 +127,7 @@ void copy_blocks(
|
||||
dim3 grid(num_layers, num_pairs);
|
||||
dim3 block(std::min(1024, numel_per_block));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
||||
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||
@ -202,9 +202,7 @@ void reshape_and_cache(
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(),
|
||||
"reshape_and_cache_kernel",
|
||||
[&] {
|
||||
@ -364,9 +362,7 @@ void gather_cached_kv(
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(),
|
||||
"gather_cached_kv_kernel_optimized",
|
||||
[&] {
|
||||
|
14
csrc/dispatch_utils.h
Normal file
14
csrc/dispatch_utils.h
Normal file
@ -0,0 +1,14 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
|
||||
*/
|
||||
#include <torch/extension.h>
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
@ -1,6 +1,7 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
#include "reduction_utils.cuh"
|
||||
|
||||
namespace vllm {
|
||||
@ -46,9 +47,7 @@ void rms_norm(
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"rms_norm_kernel",
|
||||
[&] {
|
||||
|
@ -1,15 +1,16 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
void rotary_embedding_neox(
|
||||
void rotary_embedding(
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache);
|
||||
torch::Tensor& cos_sin_cache,
|
||||
bool is_neox);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"rotary_embedding_neox",
|
||||
&rotary_embedding_neox,
|
||||
"Apply GPT-NeoX style rotary embedding to query and key");
|
||||
"rotary_embedding",
|
||||
&rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||
}
|
||||
|
@ -1,10 +1,42 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void rotary_embedding_neox_kernel(
|
||||
template<typename scalar_t, bool IS_NEOX>
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
scalar_t* __restrict__ arr,
|
||||
const scalar_t* __restrict__ cos_ptr,
|
||||
const scalar_t* __restrict__ sin_ptr,
|
||||
int rot_offset,
|
||||
int embed_dim)
|
||||
{
|
||||
int x_index, y_index;
|
||||
scalar_t cos, sin;
|
||||
if (IS_NEOX) {
|
||||
// GPT-NeoX style rotary embedding.
|
||||
x_index = rot_offset;
|
||||
y_index = embed_dim + rot_offset;
|
||||
cos = __ldg(cos_ptr + x_index);
|
||||
sin = __ldg(sin_ptr + x_index);
|
||||
} else {
|
||||
// GPT-J style rotary embedding.
|
||||
x_index = 2 * rot_offset;
|
||||
y_index = 2 * rot_offset + 1;
|
||||
cos = __ldg(cos_ptr + x_index / 2);
|
||||
sin = __ldg(sin_ptr + x_index / 2);
|
||||
}
|
||||
|
||||
const scalar_t x = arr[x_index];
|
||||
const scalar_t y = arr[y_index];
|
||||
arr[x_index] = x * cos - y * sin;
|
||||
arr[y_index] = y * cos + x * sin;
|
||||
}
|
||||
|
||||
template<typename scalar_t, bool IS_NEOX>
|
||||
__global__ void rotary_embedding_kernel(
|
||||
const int64_t* __restrict__ positions, // [num_tokens]
|
||||
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
|
||||
@ -21,58 +53,37 @@ __global__ void rotary_embedding_neox_kernel(
|
||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
const int embed_dim = rot_dim / 2;
|
||||
const scalar_t* cos_ptr = cache_ptr;
|
||||
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
||||
|
||||
const int nq = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int token_head = token_idx * query_stride + head_idx * head_size;
|
||||
|
||||
const int rot_offset = i % embed_dim;
|
||||
const int x_index = rot_offset;
|
||||
const int y_index = embed_dim + rot_offset;
|
||||
|
||||
const int out_x = token_idx * query_stride + head_idx * head_size + x_index;
|
||||
const int out_y = token_idx * query_stride + head_idx * head_size + y_index;
|
||||
|
||||
const scalar_t cos = __ldg(cache_ptr + x_index);
|
||||
const scalar_t sin = __ldg(cache_ptr + y_index);
|
||||
|
||||
const scalar_t q_x = query[token_head + x_index];
|
||||
const scalar_t q_y = query[token_head + y_index];
|
||||
query[out_x] = q_x * cos - q_y * sin;
|
||||
query[out_y] = q_y * cos + q_x * sin;
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int token_head = token_idx * key_stride + head_idx * head_size;
|
||||
|
||||
const int rot_offset = i % embed_dim;
|
||||
const int x_index = rot_offset;
|
||||
const int y_index = embed_dim + rot_offset;
|
||||
|
||||
const int out_x = token_idx * key_stride + head_idx * head_size + x_index;
|
||||
const int out_y = token_idx * key_stride + head_idx * head_size + y_index;
|
||||
|
||||
const scalar_t cos = __ldg(cache_ptr + x_index);
|
||||
const scalar_t sin = __ldg(cache_ptr + y_index);
|
||||
|
||||
const scalar_t k_x = key[token_head + x_index];
|
||||
const scalar_t k_y = key[token_head + y_index];
|
||||
key[out_x] = k_x * cos - k_y * sin;
|
||||
key[out_y] = k_y * cos + k_x * sin;
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void rotary_embedding_neox(
|
||||
void rotary_embedding(
|
||||
torch::Tensor& positions, // [num_tokens]
|
||||
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [num_tokens, num_kv_heads * head_size]
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache) // [max_position, rot_dim]
|
||||
{
|
||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
bool is_neox) {
|
||||
int num_tokens = query.size(0);
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(1) / head_size;
|
||||
@ -83,22 +94,34 @@ void rotary_embedding_neox(
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
query.scalar_type(),
|
||||
"rotary_embedding_neox",
|
||||
"rotary_embedding",
|
||||
[&] {
|
||||
vllm::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
if (is_neox) {
|
||||
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
} else {
|
||||
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -59,7 +59,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
|
||||
+ kv_caches: List[KVCache],
|
||||
+ input_metadata: InputMetadata,
|
||||
+ cache_events: Optional[List[torch.cuda.Event]],
|
||||
+) -> Dict[int, SequenceOutputs]:
|
||||
+) -> SamplerOutput:
|
||||
|
||||
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
|
||||
4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.
|
||||
|
@ -15,7 +15,7 @@ Alongside each architecture, we include some popular models that use it.
|
||||
- Models
|
||||
- Example HuggingFace Models
|
||||
* - :code:`AquilaForCausalLM`
|
||||
- Aqualia
|
||||
- Aquila
|
||||
- :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc.
|
||||
* - :code:`BaiChuanForCausalLM`
|
||||
- Baichuan
|
||||
@ -43,14 +43,14 @@ Alongside each architecture, we include some popular models that use it.
|
||||
- :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc.
|
||||
* - :code:`LlamaForCausalLM`
|
||||
- LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
|
||||
- :code:`meta-llama/Llama-2-13b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc.
|
||||
- :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, etc.
|
||||
* - :code:`MPTForCausalLM`
|
||||
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
|
||||
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
|
||||
* - :code:`OPTForCausalLM`
|
||||
- OPT, OPT-IML
|
||||
- :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
|
||||
* - :code:`OPTForCausalLM`
|
||||
* - :code:`QWenLMHeadModel`
|
||||
- Qwen
|
||||
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
|
||||
|
||||
|
@ -10,3 +10,4 @@ types-setuptools
|
||||
|
||||
# testing
|
||||
pytest
|
||||
pytest-forked
|
||||
|
@ -4,7 +4,7 @@ ray >= 2.5.1
|
||||
sentencepiece # Required for LLaMA tokenizer.
|
||||
numpy
|
||||
torch >= 2.0.0
|
||||
transformers >= 4.31.0 # Required for LLaMA-2.
|
||||
transformers >= 4.33.1 # Required for Code Llama.
|
||||
xformers >= 0.0.21
|
||||
fastapi
|
||||
uvicorn
|
||||
|
51
tests/async_engine/api_server_async_engine.py
Normal file
51
tests/async_engine/api_server_async_engine.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""vllm.entrypoints.api_server with some extra logging for testing."""
|
||||
import argparse
|
||||
from typing import Any, Dict
|
||||
|
||||
import uvicorn
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
import vllm.entrypoints.api_server
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
|
||||
app = vllm.entrypoints.api_server.app
|
||||
|
||||
|
||||
class AsyncLLMEngineWithStats(AsyncLLMEngine):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._num_aborts = 0
|
||||
|
||||
async def abort(self, request_id: str) -> None:
|
||||
await super().abort(request_id)
|
||||
self._num_aborts += 1
|
||||
|
||||
def testing_stats(self) -> Dict[str, Any]:
|
||||
return {"num_aborted_requests": self._num_aborts}
|
||||
|
||||
|
||||
@app.get("/stats")
|
||||
def stats() -> Response:
|
||||
"""Get the statistics of the engine."""
|
||||
return JSONResponse(engine.testing_stats())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngineWithStats.from_engine_args(engine_args,
|
||||
start_engine_loop=False)
|
||||
vllm.entrypoints.api_server.engine = engine
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level="debug",
|
||||
timeout_keep_alive=vllm.entrypoints.api_server.TIMEOUT_KEEP_ALIVE)
|
86
tests/async_engine/test_api_server.py
Normal file
86
tests/async_engine/test_api_server.py
Normal file
@ -0,0 +1,86 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from multiprocessing import Pool
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
def _query_server(prompt: str) -> dict:
|
||||
response = requests.post("http://localhost:8000/generate",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"max_tokens": 100,
|
||||
"temperature": 0,
|
||||
"ignore_eos": True
|
||||
})
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_server():
|
||||
script_path = Path(__file__).parent.joinpath(
|
||||
"api_server_async_engine.py").absolute()
|
||||
uvicorn_process = subprocess.Popen([
|
||||
sys.executable, "-u",
|
||||
str(script_path), "--model", "facebook/opt-125m"
|
||||
])
|
||||
yield
|
||||
uvicorn_process.terminate()
|
||||
|
||||
|
||||
def test_api_server(api_server):
|
||||
"""
|
||||
Run the API server and test it.
|
||||
|
||||
We run both the server and requests in separate processes.
|
||||
|
||||
We test that the server can handle incoming requests, including
|
||||
multiple requests at the same time, and that it can handle requests
|
||||
being cancelled without crashing.
|
||||
"""
|
||||
with Pool(32) as pool:
|
||||
# Wait until the server is ready
|
||||
prompts = ["Hello world"] * 1
|
||||
result = None
|
||||
while not result:
|
||||
try:
|
||||
for result in pool.map(_query_server, prompts):
|
||||
break
|
||||
except:
|
||||
time.sleep(1)
|
||||
|
||||
# Actual tests start here
|
||||
# Try with 1 prompt
|
||||
for result in pool.map(_query_server, prompts):
|
||||
assert result
|
||||
|
||||
num_aborted_requests = requests.get(
|
||||
"http://localhost:8000/stats").json()["num_aborted_requests"]
|
||||
assert num_aborted_requests == 0
|
||||
|
||||
# Try with 100 prompts
|
||||
prompts = ["Hello world"] * 100
|
||||
for result in pool.map(_query_server, prompts):
|
||||
assert result
|
||||
|
||||
# Cancel requests
|
||||
pool.map_async(_query_server, prompts)
|
||||
time.sleep(0.01)
|
||||
pool.terminate()
|
||||
pool.join()
|
||||
|
||||
# check cancellation stats
|
||||
num_aborted_requests = requests.get(
|
||||
"http://localhost:8000/stats").json()["num_aborted_requests"]
|
||||
assert num_aborted_requests > 0
|
||||
|
||||
# check that server still runs after cancellations
|
||||
with Pool(32) as pool:
|
||||
# Try with 100 prompts
|
||||
prompts = ["Hello world"] * 100
|
||||
for result in pool.map(_query_server, prompts):
|
||||
assert result
|
54
tests/async_engine/test_request_tracker.py
Normal file
54
tests/async_engine/test_request_tracker.py
Normal file
@ -0,0 +1,54 @@
|
||||
import pytest
|
||||
|
||||
from vllm.engine.async_llm_engine import RequestTracker
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
|
||||
def test_request_tracker():
|
||||
tracker = RequestTracker()
|
||||
stream_1 = tracker.add_request("1")
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert len(new) == 1
|
||||
assert new[0]["request_id"] == "1"
|
||||
assert not finished
|
||||
assert not stream_1.finished
|
||||
|
||||
stream_2 = tracker.add_request("2")
|
||||
stream_3 = tracker.add_request("3")
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert len(new) == 2
|
||||
assert new[0]["request_id"] == "2"
|
||||
assert new[1]["request_id"] == "3"
|
||||
assert not finished
|
||||
assert not stream_2.finished
|
||||
assert not stream_3.finished
|
||||
|
||||
# request_ids must be unique
|
||||
with pytest.raises(KeyError):
|
||||
tracker.add_request("1")
|
||||
|
||||
tracker.abort_request("1")
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert len(finished) == 1
|
||||
assert "1" in finished
|
||||
assert not new
|
||||
assert stream_1.finished
|
||||
|
||||
stream_4 = tracker.add_request("4")
|
||||
tracker.abort_request("4")
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert len(finished) == 1
|
||||
assert "4" in finished
|
||||
assert not new
|
||||
assert stream_4.finished
|
||||
|
||||
stream_5 = tracker.add_request("5")
|
||||
tracker.process_request_output(
|
||||
RequestOutput("2", "output", [], [], finished=True))
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert len(finished) == 1
|
||||
assert "2" in finished
|
||||
assert len(new) == 1
|
||||
assert new[0]["request_id"] == "5"
|
||||
assert stream_2.finished
|
||||
assert not stream_5.finished
|
178
tests/conftest.py
Normal file
178
tests/conftest.py
Normal file
@ -0,0 +1,178 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
_TEST_PROMPTS = [
|
||||
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
|
||||
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
|
||||
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
|
||||
"Describe the basic components of a neural network and how it can be trained.",
|
||||
"Write a short story about a robot that dreams for the first time.",
|
||||
"Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
|
||||
"Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
|
||||
"Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_prompts() -> List[str]:
|
||||
return _TEST_PROMPTS
|
||||
|
||||
|
||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"half": torch.half,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float": torch.float,
|
||||
}
|
||||
|
||||
|
||||
class HfRunner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
tokenizer_name: Optional[str] = None,
|
||||
dtype: str = "half",
|
||||
) -> None:
|
||||
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
|
||||
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
).cuda()
|
||||
if tokenizer_name is None:
|
||||
tokenizer_name = model_name
|
||||
self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
**kwargs,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
outputs: List[Tuple[List[int], str]] = []
|
||||
for prompt in prompts:
|
||||
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
|
||||
output_ids = self.model.generate(
|
||||
input_ids.cuda(),
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
)
|
||||
output_str = self.tokenizer.batch_decode(
|
||||
output_ids,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
output_ids = output_ids.cpu().tolist()
|
||||
outputs.append((output_ids, output_str))
|
||||
return outputs
|
||||
|
||||
def generate_greedy(
|
||||
self,
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
outputs = self.generate(prompts,
|
||||
do_sample=False,
|
||||
max_new_tokens=max_tokens)
|
||||
for i in range(len(outputs)):
|
||||
output_ids, output_str = outputs[i]
|
||||
outputs[i] = (output_ids[0], output_str[0])
|
||||
return outputs
|
||||
|
||||
def generate_beam_search(
|
||||
self,
|
||||
prompts: List[str],
|
||||
beam_width: int,
|
||||
max_tokens: int,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
outputs = self.generate(prompts,
|
||||
do_sample=False,
|
||||
max_new_tokens=max_tokens,
|
||||
num_beams=beam_width,
|
||||
num_return_sequences=beam_width)
|
||||
for i in range(len(outputs)):
|
||||
output_ids, output_str = outputs[i]
|
||||
for j in range(len(output_ids)):
|
||||
output_ids[j] = [
|
||||
x for x in output_ids[j]
|
||||
if x != self.tokenizer.pad_token_id
|
||||
]
|
||||
outputs[i] = (output_ids, output_str)
|
||||
return outputs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hf_runner():
|
||||
return HfRunner
|
||||
|
||||
|
||||
class VllmRunner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
tokenizer_name: Optional[str] = None,
|
||||
dtype: str = "half",
|
||||
) -> None:
|
||||
self.model = LLM(
|
||||
model=model_name,
|
||||
tokenizer=tokenizer_name,
|
||||
trust_remote_code=True,
|
||||
dtype=dtype,
|
||||
swap_space=0,
|
||||
)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
sampling_params: SamplingParams,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
req_outputs = self.model.generate(prompts,
|
||||
sampling_params=sampling_params)
|
||||
outputs = []
|
||||
for req_output in req_outputs:
|
||||
prompt_str = req_output.prompt
|
||||
prompt_ids = req_output.prompt_token_ids
|
||||
req_sample_output_ids = []
|
||||
req_sample_output_strs = []
|
||||
for sample in req_output.outputs:
|
||||
output_str = sample.text
|
||||
output_ids = sample.token_ids
|
||||
req_sample_output_ids.append(prompt_ids + output_ids)
|
||||
req_sample_output_strs.append(prompt_str + output_str)
|
||||
outputs.append((req_sample_output_ids, req_sample_output_strs))
|
||||
return outputs
|
||||
|
||||
def generate_greedy(
|
||||
self,
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||
outputs = self.generate(prompts, greedy_params)
|
||||
return [(output_ids[0], output_str[0])
|
||||
for output_ids, output_str in outputs]
|
||||
|
||||
def generate_beam_search(
|
||||
self,
|
||||
prompts: List[str],
|
||||
beam_width: int,
|
||||
max_tokens: int,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
beam_search_params = SamplingParams(n=beam_width,
|
||||
use_beam_search=True,
|
||||
temperature=0.0,
|
||||
max_tokens=max_tokens)
|
||||
outputs = self.generate(prompts, beam_search_params)
|
||||
return outputs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vllm_runner():
|
||||
return VllmRunner
|
43
tests/kernels/conftest.py
Normal file
43
tests/kernels/conftest.py
Normal file
@ -0,0 +1,43 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
def create_kv_caches(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
scale = head_size**-0.5
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
||||
key_caches = []
|
||||
for _ in range(num_layers):
|
||||
key_cache = torch.empty(size=key_cache_shape,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
key_cache.uniform_(-scale, scale)
|
||||
key_caches.append(key_cache)
|
||||
|
||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
||||
value_caches = []
|
||||
for _ in range(num_layers):
|
||||
value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
value_cache.uniform_(-scale, scale)
|
||||
value_caches.append(value_cache)
|
||||
return key_caches, value_caches
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def kv_cache_factory():
|
||||
return create_kv_caches
|
@ -1,20 +1,34 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers.activations import get_activation
|
||||
|
||||
from vllm import activation_ops
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
|
||||
D = [512, 4096, 5120, 13824] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
|
||||
x1, x2 = x.chunk(chunks=2, dim=1)
|
||||
return F.silu(x1) * x2
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("d", D)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def run_silu_and_mul(
|
||||
def test_silu_and_mul(
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||
activation_ops.silu_and_mul(out, x)
|
||||
@ -22,20 +36,19 @@ def run_silu_and_mul(
|
||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_silu_and_mul() -> None:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
for num_tokens in [7, 83, 2048]:
|
||||
for d in [512, 4096, 5120, 13824]:
|
||||
print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
|
||||
run_silu_and_mul(num_tokens, d, dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("d", D)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def run_gelu_new(
|
||||
def test_gelu_new(
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||
activation_ops.gelu_new(out, x)
|
||||
@ -43,30 +56,20 @@ def run_gelu_new(
|
||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_gelu_new() -> None:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
for num_tokens in [7, 83, 2048]:
|
||||
for d in [512, 4096, 5120, 13824]:
|
||||
print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
|
||||
run_gelu_new(num_tokens, d, dtype)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_gelu_fast(
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("d", D)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
def test_gelu_fast(
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||
activation_ops.gelu_fast(out, x)
|
||||
ref_out = get_activation("gelu_fast")(x)
|
||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_gelu_fast() -> None:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
for num_tokens in [7, 83, 2048]:
|
||||
for d in [512, 4096, 5120, 13824]:
|
||||
print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
|
||||
run_gelu_fast(num_tokens, d, dtype)
|
||||
|
@ -1,14 +1,24 @@
|
||||
import random
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||
|
||||
from vllm import attention_ops
|
||||
|
||||
MAX_SEQ_LEN = 4096
|
||||
TEST_SEED = 0
|
||||
MAX_SEQ_LEN = 8192
|
||||
NUM_BLOCKS = 128 # Arbitrary values for testing
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
NUM_GEN_SEQS = [7] # Arbitrary values for testing
|
||||
NUM_PREFILL_SEQS = [1, 3, 7] # Arbitrary values for testing
|
||||
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
|
||||
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||
BLOCK_SIZES = [8, 16, 32]
|
||||
USE_ALIBI = [False, True]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def ref_masked_attention(
|
||||
@ -18,29 +28,34 @@ def ref_masked_attention(
|
||||
scale: float,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
query = query * scale
|
||||
attn = torch.einsum('qhd,khd->hqk', query, key)
|
||||
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
|
||||
if attn_mask is not None:
|
||||
attn = attn + attn_mask
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
out = torch.einsum('hqk,khd->qhd', attn, value)
|
||||
attn_weights = attn_weights + attn_mask.float()
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
||||
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
|
||||
return out
|
||||
|
||||
|
||||
def ref_single_query_cached_kv_attention(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
num_queries_per_kv: int,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
) -> None:
|
||||
num_heads = value_cache.shape[1]
|
||||
num_query_heads = query.shape[1]
|
||||
num_kv_heads = value_cache.shape[1]
|
||||
head_size = value_cache.shape[2]
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs = query.shape[0]
|
||||
|
||||
num_input_tokens = query.shape[0]
|
||||
for i in range(num_input_tokens):
|
||||
block_tables = block_tables.cpu().tolist()
|
||||
context_lens = context_lens.cpu().tolist()
|
||||
for i in range(num_seqs):
|
||||
q = query[i].unsqueeze(0)
|
||||
block_table = block_tables[i]
|
||||
context_len = int(context_lens[i])
|
||||
@ -52,30 +67,138 @@ def ref_single_query_cached_kv_attention(
|
||||
block_offset = j % block_size
|
||||
|
||||
k = key_cache[block_number, :, :, block_offset, :]
|
||||
k = k.reshape(num_heads, head_size)
|
||||
k = k.reshape(num_kv_heads, head_size)
|
||||
keys.append(k)
|
||||
|
||||
v = value_cache[block_number, :, :, block_offset]
|
||||
values.append(v)
|
||||
keys = torch.stack(keys, dim=0)
|
||||
values = torch.stack(values, dim=0)
|
||||
if num_queries_per_kv > 1:
|
||||
# Handle MQA and GQA
|
||||
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
|
||||
values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
|
||||
|
||||
scale = 1.0 / (head_size**0.5)
|
||||
out = ref_masked_attention(q, keys, values, scale)
|
||||
out = out.view(num_heads, head_size)
|
||||
alibi_bias = None
|
||||
if alibi_slopes is not None:
|
||||
# Create the ALiBi bias used in the paged attention kernel.
|
||||
position_ids = torch.arange(context_len, device="cuda").int()
|
||||
alibi_bias = (position_ids - context_len + 1).float()
|
||||
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
|
||||
1, 1, -1)
|
||||
|
||||
out = ref_masked_attention(q, keys, values, scale, alibi_bias)
|
||||
out = out.view(num_query_heads, head_size)
|
||||
output[i].copy_(out, non_blocking=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_single_query_cached_kv_attention(
|
||||
kv_cache_factory,
|
||||
num_seqs: int,
|
||||
num_heads: Tuple[int, int],
|
||||
head_size: int,
|
||||
use_alibi: bool,
|
||||
block_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
num_query_heads, num_kv_heads = num_heads
|
||||
query = torch.empty(num_seqs,
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
query.uniform_(-scale, scale)
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
head_mapping = torch.repeat_interleave(
|
||||
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
|
||||
num_queries_per_kv)
|
||||
alibi_slopes = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads,
|
||||
dtype=torch.float,
|
||||
device="cuda")
|
||||
|
||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||
block_tables = []
|
||||
for _ in range(num_seqs):
|
||||
block_table = [
|
||||
random.randint(0, NUM_BLOCKS - 1)
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
|
||||
num_kv_heads, head_size, dtype,
|
||||
seed)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Call the paged attention kernel.
|
||||
output = torch.empty_like(query)
|
||||
attention_ops.single_query_cached_kv_attention(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
head_mapping,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
block_size,
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
)
|
||||
|
||||
# Run the reference implementation.
|
||||
ref_output = torch.empty_like(query)
|
||||
ref_single_query_cached_kv_attention(
|
||||
ref_output,
|
||||
query,
|
||||
num_queries_per_kv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
context_lens,
|
||||
scale,
|
||||
alibi_slopes,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): Due to the kernel-level differences in the two
|
||||
# implementations, there is a small numerical difference in the two
|
||||
# outputs. Thus, we use a relaxed tolerance for the test.
|
||||
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
||||
|
||||
|
||||
def ref_multi_query_kv_attention(
|
||||
cu_seq_lens: List[int],
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
head_size = query.shape[-1]
|
||||
scale = 1.0 / (head_size**0.5)
|
||||
|
||||
num_seqs = len(cu_seq_lens) - 1
|
||||
ref_outputs = []
|
||||
for i in range(num_seqs):
|
||||
@ -87,7 +210,7 @@ def ref_multi_query_kv_attention(
|
||||
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
|
||||
diagonal=1)
|
||||
attn_mask = attn_mask * torch.finfo(dtype).min
|
||||
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
|
||||
attn_mask = attn_mask.to(dtype=dtype, device="cuda")
|
||||
|
||||
ref_output = ref_masked_attention(
|
||||
query[start_idx:end_idx],
|
||||
@ -101,172 +224,43 @@ def ref_multi_query_kv_attention(
|
||||
return ref_output
|
||||
|
||||
|
||||
def ref_multi_query_cached_kv_attention(
|
||||
cu_query_lens: List[int],
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
num_heads = value_cache.shape[1]
|
||||
head_size = value_cache.shape[2]
|
||||
block_size = value_cache.shape[3]
|
||||
scale = 1.0 / (head_size**0.5)
|
||||
|
||||
num_queries = len(cu_query_lens) - 1
|
||||
ref_outputs = []
|
||||
for i in range(num_queries):
|
||||
start_idx = cu_query_lens[i]
|
||||
end_idx = cu_query_lens[i + 1]
|
||||
query_len = end_idx - start_idx
|
||||
context_len = int(context_lens[i])
|
||||
block_table = block_tables[i]
|
||||
|
||||
# Create attention mask
|
||||
attn_mask = torch.triu(torch.ones(query_len, context_len),
|
||||
diagonal=context_len - query_len + 1) * -1e5
|
||||
attn_mask = attn_mask.to(dtype=dtype, device='cuda')
|
||||
|
||||
keys = []
|
||||
values = []
|
||||
for j in range(context_len):
|
||||
block_number = int(block_table[j // block_size])
|
||||
block_offset = j % block_size
|
||||
|
||||
k = key_cache[block_number, :, :, block_offset, :]
|
||||
k = k.reshape(num_heads, head_size)
|
||||
keys.append(k)
|
||||
|
||||
v = value_cache[block_number, :, :, block_offset]
|
||||
values.append(v)
|
||||
keys = torch.stack(keys, dim=0)
|
||||
values = torch.stack(values, dim=0)
|
||||
|
||||
ref_output = ref_masked_attention(
|
||||
query[start_idx:end_idx],
|
||||
keys,
|
||||
values,
|
||||
scale,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
ref_outputs.append(ref_output)
|
||||
ref_output = torch.cat(ref_outputs, dim=0)
|
||||
return ref_output
|
||||
|
||||
|
||||
# TODO(woosuk): Add tests for USE_ALIBI=True.
|
||||
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def run_single_query_cached_kv_attention(
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
num_kv_heads: int = None,
|
||||
) -> None:
|
||||
qkv = torch.empty(num_tokens,
|
||||
3,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
qkv.uniform_(-1e-3, 1e-3)
|
||||
query, _, _ = qkv.unbind(dim=1)
|
||||
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_block_shape = (num_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.empty(size=(num_blocks, *key_block_shape),
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
key_cache.uniform_(-1e-3, 1e-3)
|
||||
value_block_shape = (num_heads, head_size, block_size)
|
||||
value_cache = torch.empty(size=(num_blocks, *value_block_shape),
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
value_cache.uniform_(-1e-3, 1e-3)
|
||||
|
||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
|
||||
|
||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||
block_tables = []
|
||||
for _ in range(num_tokens):
|
||||
block_table = [
|
||||
random.randint(0, num_blocks - 1)
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
|
||||
head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda")
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
|
||||
num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
assert num_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_heads // num_kv_heads
|
||||
head_mapping = torch.repeat_interleave(
|
||||
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
|
||||
num_queries_per_kv)
|
||||
|
||||
output = torch.empty(num_tokens,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
attention_ops.single_query_cached_kv_attention(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
head_mapping,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
block_size,
|
||||
max_context_len,
|
||||
None, # ALiBi slopes.
|
||||
)
|
||||
|
||||
ref_output = torch.empty_like(query)
|
||||
ref_single_query_cached_kv_attention(
|
||||
ref_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
context_lens,
|
||||
)
|
||||
# NOTE(woosuk): Due to the difference in the data types the two
|
||||
# implementations use for attention softmax logits and accumulation,
|
||||
# there is a small difference in the final outputs.
|
||||
# We should use a relaxed tolerance for the test.
|
||||
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_multi_query_kv_attention(
|
||||
def test_multi_query_kv_attention(
|
||||
num_seqs: int,
|
||||
num_heads: int,
|
||||
num_heads: Tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
|
||||
num_tokens = sum(seq_lens)
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
num_query_heads, num_kv_heads = num_heads
|
||||
qkv = torch.empty(num_tokens,
|
||||
3,
|
||||
num_heads,
|
||||
num_query_heads + 2 * num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
qkv.uniform_(-1e-3, 1e-3)
|
||||
query, key, value = qkv.unbind(dim=1)
|
||||
device="cuda")
|
||||
qkv.uniform_(-scale, scale)
|
||||
query, key, value = qkv.split(
|
||||
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)
|
||||
|
||||
attn_op = xops.fmha.cutlass.FwOp()
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
if num_queries_per_kv > 1:
|
||||
# Handle MQA and GQA
|
||||
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
|
||||
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
|
||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
|
||||
output = xops.memory_efficient_attention_forward(
|
||||
query.unsqueeze(0),
|
||||
@ -275,7 +269,6 @@ def run_multi_query_kv_attention(
|
||||
attn_bias=attn_bias,
|
||||
p=0.0,
|
||||
scale=scale,
|
||||
op=attn_op,
|
||||
)
|
||||
output = output.squeeze(0)
|
||||
|
||||
@ -287,40 +280,7 @@ def run_multi_query_kv_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
scale,
|
||||
dtype,
|
||||
)
|
||||
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
|
||||
|
||||
|
||||
def test_single_query_cached_kv_attention() -> None:
|
||||
torch.random.manual_seed(TEST_SEED)
|
||||
torch.cuda.manual_seed(TEST_SEED)
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
for block_size in [8, 16, 32]:
|
||||
for head_size in [64, 80, 96, 112, 128, 256]:
|
||||
print(f'Testing single_query_cached_kv_attention with '
|
||||
f'dtype={dtype}, block_size={block_size}, '
|
||||
f'head_size={head_size}')
|
||||
run_single_query_cached_kv_attention(
|
||||
num_tokens=37,
|
||||
num_heads=3,
|
||||
head_size=head_size,
|
||||
block_size=block_size,
|
||||
num_blocks=1024,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
def test_multi_query_kv_attention() -> None:
|
||||
torch.random.manual_seed(TEST_SEED)
|
||||
torch.cuda.manual_seed(TEST_SEED)
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
for head_size in [64, 80, 96, 112, 128, 256]:
|
||||
print(f'Testing multi_query_kv_attention with dtype={dtype}, '
|
||||
f'head_size={head_size}')
|
||||
run_multi_query_kv_attention(
|
||||
num_seqs=5,
|
||||
num_heads=3,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -1,12 +1,32 @@
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import cache_ops
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
|
||||
NUM_LAYERS = [5] # Arbitrary values for testing
|
||||
NUM_HEADS = [8] # Arbitrary values for testing
|
||||
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||
BLOCK_SIZES = [8, 16, 32]
|
||||
NUM_BLOCKS = [1024] # Arbitrary values for testing
|
||||
NUM_MAPPINGS = [32, 256] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
||||
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def run_copy_blocks(
|
||||
def test_copy_blocks(
|
||||
kv_cache_factory,
|
||||
num_mappings: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
@ -14,48 +34,43 @@ def run_copy_blocks(
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
# Generate random block mappings.
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
# Generate random block mappings where each source block is mapped to two
|
||||
# destination blocks.
|
||||
assert 2 * num_mappings <= num_blocks
|
||||
src_blocks = random.sample(range(num_blocks), num_mappings)
|
||||
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
|
||||
dst_blocks = random.sample(remainig_blocks, num_mappings)
|
||||
block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}
|
||||
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
|
||||
block_mapping = {}
|
||||
for i in range(num_mappings):
|
||||
src = src_blocks[i]
|
||||
dst1 = dst_blocks[2 * i]
|
||||
dst2 = dst_blocks[2 * i + 1]
|
||||
block_mapping[src] = [dst1, dst2]
|
||||
|
||||
# Create the KV cache.
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
||||
key_caches = []
|
||||
for _ in range(num_layers):
|
||||
key_cache = torch.randn(size=key_cache_shape,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
key_caches.append(key_cache)
|
||||
cloned_key_caches = []
|
||||
for key_cache in key_caches:
|
||||
cloned_key_caches.append(key_cache.clone())
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
|
||||
num_layers, num_heads,
|
||||
head_size, dtype, seed)
|
||||
|
||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
||||
value_caches = []
|
||||
for _ in range(num_layers):
|
||||
value_cache = torch.randn(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
value_caches.append(value_cache)
|
||||
cloned_value_caches = []
|
||||
for value_cache in value_caches:
|
||||
cloned_value_caches.append(value_cache.clone())
|
||||
# Clone the KV caches.
|
||||
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
|
||||
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
|
||||
|
||||
# Call the copy blocks kernel.
|
||||
cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
|
||||
|
||||
# Reference implementation.
|
||||
# Run the reference implementation.
|
||||
for src, dsts in block_mapping.items():
|
||||
for dst in dsts:
|
||||
for key_cache, cloned_key_cache in zip(key_caches,
|
||||
cloned_key_caches):
|
||||
for cloned_key_cache in cloned_key_caches:
|
||||
cloned_key_cache[dst] = cloned_key_cache[src]
|
||||
for value_cache, cloned_value_cache in zip(value_caches,
|
||||
cloned_value_caches):
|
||||
for cloned_value_cache in cloned_value_caches:
|
||||
cloned_value_cache[dst] = cloned_value_cache[src]
|
||||
|
||||
# Compare the results.
|
||||
@ -66,15 +81,29 @@ def run_copy_blocks(
|
||||
assert torch.allclose(value_cache, cloned_value_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def run_reshape_and_cache(
|
||||
def test_reshape_and_cache(
|
||||
kv_cache_factory,
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
# Create a random slot mapping.
|
||||
num_slots = block_size * num_blocks
|
||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
||||
@ -87,110 +116,31 @@ def run_reshape_and_cache(
|
||||
device='cuda')
|
||||
_, key, value = qkv.unbind(dim=1)
|
||||
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
|
||||
cloned_key_cache = key_cache.clone()
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
|
||||
num_heads, head_size, dtype,
|
||||
seed)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
||||
value_cache = torch.randn(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
# Clone the KV caches.
|
||||
cloned_key_cache = key_cache.clone()
|
||||
cloned_value_cache = value_cache.clone()
|
||||
|
||||
# Call the reshape_and_cache kernel.
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
|
||||
slot_mapping)
|
||||
|
||||
# Run the reference implementation.
|
||||
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
|
||||
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
|
||||
block_indicies = block_indicies.cpu().tolist()
|
||||
block_offsets = slot_mapping % block_size
|
||||
block_offsets = block_offsets.cpu().tolist()
|
||||
for i in range(num_tokens):
|
||||
reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
|
||||
block_idx = torch.div(slot_mapping[i],
|
||||
block_size,
|
||||
rounding_mode='floor')
|
||||
block_offset = slot_mapping[i] % block_size
|
||||
block_idx = block_indicies[i]
|
||||
block_offset = block_offsets[i]
|
||||
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
|
||||
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
|
||||
|
||||
assert torch.allclose(key_cache, cloned_key_cache)
|
||||
assert torch.allclose(value_cache, cloned_value_cache)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_gather_cached_kv(
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
num_slots = block_size * num_blocks
|
||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
||||
|
||||
qkv = torch.randn(num_tokens,
|
||||
3,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
_, key, value = qkv.unbind(dim=1)
|
||||
|
||||
qkv_clone = qkv.clone()
|
||||
_, cloned_key, cloned_value = qkv_clone.unbind(dim=1)
|
||||
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
|
||||
|
||||
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
|
||||
value_cache = torch.randn(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
|
||||
cache_ops.gather_cached_kv(key, value, key_cache, value_cache,
|
||||
slot_mapping)
|
||||
|
||||
# Reference implementation.
|
||||
for i in range(num_tokens):
|
||||
reshaped_key = cloned_key.reshape(num_tokens, num_heads,
|
||||
head_size // x, x)
|
||||
block_idx = torch.div(slot_mapping[i],
|
||||
block_size,
|
||||
rounding_mode='floor')
|
||||
block_offset = slot_mapping[i] % block_size
|
||||
reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
|
||||
cloned_value[i] = value_cache[block_idx, :, :, block_offset]
|
||||
|
||||
assert torch.allclose(key, cloned_key)
|
||||
assert torch.allclose(value, cloned_value)
|
||||
|
||||
|
||||
def test_copy_blocks() -> None:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
run_copy_blocks(num_mappings=23,
|
||||
num_layers=7,
|
||||
num_heads=17,
|
||||
head_size=16,
|
||||
block_size=8,
|
||||
num_blocks=1024,
|
||||
dtype=dtype)
|
||||
|
||||
|
||||
def test_reshape_and_cache() -> None:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
run_reshape_and_cache(num_tokens=3,
|
||||
num_heads=2,
|
||||
head_size=16,
|
||||
block_size=8,
|
||||
num_blocks=2,
|
||||
dtype=dtype)
|
||||
|
||||
|
||||
def test_gather_cached_kv() -> None:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
run_gather_cached_kv(num_tokens=3,
|
||||
num_heads=2,
|
||||
head_size=16,
|
||||
block_size=8,
|
||||
num_blocks=2,
|
||||
dtype=dtype)
|
||||
|
@ -1,35 +1,50 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm import layernorm_ops
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
HIDDEN_SIZES = [67, 768, 2048, 5120, 8192] # Arbitrary values for testing
|
||||
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
class RefRMSNorm(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__()
|
||||
weight = torch.empty(hidden_size)
|
||||
weight.uniform_(-1e-3, 1e-3)
|
||||
weight.normal_(mean=1.0, std=0.1)
|
||||
self.weight = nn.Parameter(weight)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1,
|
||||
keepdim=True)
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance +
|
||||
self.variance_epsilon)
|
||||
if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
return self.weight * hidden_states
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def run_rms_norm(
|
||||
def test_rms_norm(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
) -> None:
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
scale = float(hidden_size**-0.5)
|
||||
x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||
x.uniform_(-scale, scale)
|
||||
ref = RefRMSNorm(hidden_size).to(dtype).cuda()
|
||||
|
||||
out = torch.empty_like(x)
|
||||
@ -40,17 +55,4 @@ def run_rms_norm(
|
||||
ref.variance_epsilon,
|
||||
)
|
||||
ref_out = ref(x)
|
||||
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)
|
||||
|
||||
|
||||
def test_rms_norm() -> None:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
for num_tokens in [7, 128, 2048]:
|
||||
for hidden_size in [13, 64, 1024, 5120]:
|
||||
print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
|
||||
f'{num_tokens}, hidden_size={hidden_size}')
|
||||
run_rms_norm(
|
||||
num_tokens=num_tokens,
|
||||
hidden_size=hidden_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5)
|
||||
|
@ -1,47 +1,70 @@
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import pos_encoding_ops
|
||||
|
||||
IS_NEOX_STYLE = [True, False]
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
|
||||
NUM_HEADS = [7, 12, 40, 52] # Arbitrary values for testing
|
||||
NUM_TOKENS = [11, 83, 2048] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
|
||||
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
||||
x1 = x[..., :x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
x = torch.stack((-x2, x1), dim=-1)
|
||||
return x.flatten(-2)
|
||||
|
||||
|
||||
def apply_rope(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
rotate_fn = rotate_neox if is_neox_style else rotate_gptj
|
||||
q_embed = (q * cos) + (rotate_fn(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_fn(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class RefRotaryEmbeddingNeox(nn.Module):
|
||||
"""Reference implementation of the GPT-NeoX style rotary embedding."""
|
||||
class RefRotaryEmbedding(nn.Module):
|
||||
"""Reference implementation of rotary embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
max_position_embeddings: int = 2048,
|
||||
is_neox_style: bool,
|
||||
max_position_embeddings: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.rotary_dim = dim
|
||||
self.is_neox_style = is_neox_style
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
# Create cos and sin embeddings.
|
||||
inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
|
||||
t = torch.arange(max_position_embeddings).float()
|
||||
freqs = torch.einsum("i,j->ij", t, inv_freq.float())
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
if is_neox_style:
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
else:
|
||||
emb = torch.repeat_interleave(freqs, 2, -1)
|
||||
cos = emb.cos().to(dtype=inv_freq.dtype)
|
||||
sin = emb.sin().to(dtype=inv_freq.dtype)
|
||||
self.register_buffer("cos_cached", cos, persistent=False)
|
||||
@ -53,7 +76,6 @@ class RefRotaryEmbeddingNeox(nn.Module):
|
||||
query: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
query_rot = query[..., :self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim:]
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
@ -63,7 +85,9 @@ class RefRotaryEmbeddingNeox(nn.Module):
|
||||
key_rot = key_rot.transpose(0, 1)
|
||||
cos = F.embedding(positions, self.cos_cached)
|
||||
sin = F.embedding(positions, self.sin_cached)
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
|
||||
|
||||
query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
|
||||
self.is_neox_style)
|
||||
query_rot = query_rot.transpose(0, 1).contiguous()
|
||||
key_rot = key_rot.transpose(0, 1).contiguous()
|
||||
|
||||
@ -74,30 +98,44 @@ class RefRotaryEmbeddingNeox(nn.Module):
|
||||
return query, key
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def run_rotary_embedding_neox(
|
||||
def test_rotary_embedding(
|
||||
is_neox_style: bool,
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
max_position: int,
|
||||
rotary_dim: int,
|
||||
rotary_dim: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
|
||||
query = torch.randn(num_tokens,
|
||||
num_heads * head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
device="cuda")
|
||||
key = torch.randn(num_tokens,
|
||||
num_heads * head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
device="cuda")
|
||||
|
||||
# Create the rotary embedding.
|
||||
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
|
||||
t = torch.arange(max_position).float()
|
||||
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
||||
@ -106,20 +144,22 @@ def run_rotary_embedding_neox(
|
||||
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
|
||||
out_query = query.clone()
|
||||
out_key = key.clone()
|
||||
pos_encoding_ops.rotary_embedding_neox(
|
||||
pos_encoding_ops.rotary_embedding(
|
||||
positions,
|
||||
out_query,
|
||||
out_key,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox_style,
|
||||
)
|
||||
|
||||
# Run the reference implementation.
|
||||
ref_rotary_embedding = RefRotaryEmbeddingNeox(
|
||||
ref_rotary_embedding = RefRotaryEmbedding(
|
||||
dim=rotary_dim,
|
||||
is_neox_style=is_neox_style,
|
||||
max_position_embeddings=max_position,
|
||||
base=base,
|
||||
).to(dtype=dtype, device='cuda')
|
||||
).to(dtype=dtype, device="cuda")
|
||||
ref_query, ref_key = ref_rotary_embedding(
|
||||
positions,
|
||||
query.view(num_tokens, num_heads, head_size),
|
||||
@ -129,19 +169,5 @@ def run_rotary_embedding_neox(
|
||||
ref_key = ref_key.view(num_tokens, num_heads * head_size)
|
||||
|
||||
# Compare the results.
|
||||
assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
|
||||
assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
|
||||
|
||||
|
||||
def test_rotary_embedding_neox() -> None:
|
||||
for dtype in [torch.half, torch.bfloat16, torch.float]:
|
||||
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
|
||||
print(f'Running tests for head_size={head_size} and dtype={dtype}')
|
||||
run_rotary_embedding_neox(
|
||||
num_tokens=2145,
|
||||
num_heads=5,
|
||||
head_size=head_size,
|
||||
max_position=8192,
|
||||
rotary_dim=head_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
|
||||
assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
|
||||
|
45
tests/models/test_models.py
Normal file
45
tests/models/test_models.py
Normal file
@ -0,0 +1,45 @@
|
||||
"""Compare the outputs of HF and vLLM when using greedy sampling.
|
||||
|
||||
Run `pytest tests/models/test_models.py --forked`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
MODELS = [
|
||||
"facebook/opt-125m",
|
||||
"gpt2",
|
||||
"bigcode/tiny_starcoder_py",
|
||||
"EleutherAI/gpt-j-6b",
|
||||
"EleutherAI/pythia-70m",
|
||||
"bigscience/bloom-560m",
|
||||
"mosaicml/mpt-7b",
|
||||
"tiiuae/falcon-7b",
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
46
tests/samplers/test_beam_search.py
Normal file
46
tests/samplers/test_beam_search.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""Compare the outputs of HF and vLLM when using beam search.
|
||||
|
||||
Run `pytest tests/samplers/test_beam_search.py --forked`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
# FIXME(zhuohan): The test can not pass if we:
|
||||
# 1. Increase max_tokens to 256.
|
||||
# 2. Increase beam_width to 8.
|
||||
# 3. Use the model "huggyllama/llama-7b".
|
||||
MAX_TOKENS = [128]
|
||||
BEAM_WIDTHS = [4]
|
||||
MODELS = ["facebook/opt-125m"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
|
||||
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
|
||||
def test_beam_search_single_input(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
beam_width: int,
|
||||
) -> None:
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
|
||||
max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype)
|
||||
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
|
||||
max_tokens)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, _ = hf_outputs[i]
|
||||
vllm_output_ids, _ = vllm_outputs[i]
|
||||
assert len(hf_output_ids) == len(vllm_output_ids)
|
||||
for j in range(len(hf_output_ids)):
|
||||
assert hf_output_ids[j] == vllm_output_ids[j], (
|
||||
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
|
||||
f"vLLM: {vllm_output_ids}")
|
@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
__version__ = "0.1.4"
|
||||
__version__ = "0.1.5"
|
||||
|
||||
__all__ = [
|
||||
"LLM",
|
||||
|
@ -24,9 +24,16 @@ class ModelConfig:
|
||||
downloading the model and tokenizer.
|
||||
download_dir: Directory to download and load the weights, default to the
|
||||
default cache directory of huggingface.
|
||||
use_np_weights: Save a numpy copy of model weights for faster loading.
|
||||
This can increase the disk usage by up to 2x.
|
||||
use_dummy_weights: Use dummy values for model weights (for profiling).
|
||||
load_format: The format of the model weights to load:
|
||||
"auto" will try to load the weights in the safetensors format and
|
||||
fall back to the pytorch bin format if safetensors format is
|
||||
not available.
|
||||
"pt" will load the weights in the pytorch bin format.
|
||||
"safetensors" will load the weights in the safetensors format.
|
||||
"npcache" will load the weights in pytorch format and store
|
||||
a numpy cache to speed up the loading.
|
||||
"dummy" will initialize the weights with random values, which is
|
||||
mainly for profiling.
|
||||
dtype: Data type for model weights and activations. The "auto" option
|
||||
will use FP16 precision for FP32 and FP16 models, and BF16 precision
|
||||
for BF16 models.
|
||||
@ -40,8 +47,7 @@ class ModelConfig:
|
||||
tokenizer_mode: str,
|
||||
trust_remote_code: bool,
|
||||
download_dir: Optional[str],
|
||||
use_np_weights: bool,
|
||||
use_dummy_weights: bool,
|
||||
load_format: str,
|
||||
dtype: str,
|
||||
seed: int,
|
||||
) -> None:
|
||||
@ -50,14 +56,24 @@ class ModelConfig:
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
self.trust_remote_code = trust_remote_code
|
||||
self.download_dir = download_dir
|
||||
self.use_np_weights = use_np_weights
|
||||
self.use_dummy_weights = use_dummy_weights
|
||||
self.load_format = load_format
|
||||
self.seed = seed
|
||||
|
||||
self.hf_config = get_config(model, trust_remote_code)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
||||
self._verify_load_format()
|
||||
self._verify_tokenizer_mode()
|
||||
|
||||
def _verify_load_format(self) -> None:
|
||||
load_format = self.load_format.lower()
|
||||
if load_format not in [
|
||||
"auto", "pt", "safetensors", "npcache", "dummy"
|
||||
]:
|
||||
raise ValueError(
|
||||
f"Unknown load format: {self.load_format}. Must be one of "
|
||||
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
|
||||
self.load_format = load_format
|
||||
|
||||
def _verify_tokenizer_mode(self) -> None:
|
||||
tokenizer_mode = self.tokenizer_mode.lower()
|
||||
if tokenizer_mode not in ["auto", "slow"]:
|
||||
|
@ -172,9 +172,7 @@ class BlockSpaceManager:
|
||||
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
|
||||
# CPU block -> GPU block.
|
||||
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
|
||||
for seq in seq_group.get_seqs():
|
||||
if seq.is_finished():
|
||||
continue
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
|
||||
new_block_table: BlockTable = []
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
|
||||
@ -203,9 +201,7 @@ class BlockSpaceManager:
|
||||
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
|
||||
# GPU block -> CPU block.
|
||||
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
|
||||
for seq in seq_group.get_seqs():
|
||||
if seq.is_finished():
|
||||
continue
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
new_block_table: BlockTable = []
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
|
||||
|
@ -1,14 +1,13 @@
|
||||
import enum
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from vllm.config import CacheConfig, SchedulerConfig
|
||||
from vllm.core.block_manager import BlockSpaceManager
|
||||
from vllm.core.policy import PolicyFactory
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceOutputs,
|
||||
SequenceStatus)
|
||||
SequenceGroupMetadata, SequenceStatus)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -64,6 +63,9 @@ class Scheduler:
|
||||
self.scheduler_config = scheduler_config
|
||||
self.cache_config = cache_config
|
||||
|
||||
self.prompt_limit = min(self.scheduler_config.max_model_len,
|
||||
self.scheduler_config.max_num_batched_tokens)
|
||||
|
||||
# Instantiate the scheduling policy.
|
||||
self.policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
# Create the block space manager.
|
||||
@ -73,6 +75,7 @@ class Scheduler:
|
||||
num_cpu_blocks=self.cache_config.num_cpu_blocks,
|
||||
)
|
||||
|
||||
# TODO(zhuohan): Use deque instead of list for better performance.
|
||||
# Sequence groups in the WAITING state.
|
||||
self.waiting: List[SequenceGroup] = []
|
||||
# Sequence groups in the RUNNING state.
|
||||
@ -84,17 +87,26 @@ class Scheduler:
|
||||
# Add sequence groups to the waiting queue.
|
||||
self.waiting.append(seq_group)
|
||||
|
||||
def abort_seq_group(self, request_id: str) -> None:
|
||||
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||
if isinstance(request_id, str):
|
||||
request_id = (request_id, )
|
||||
request_ids = set(request_id)
|
||||
for state_queue in [self.waiting, self.running, self.swapped]:
|
||||
for seq_group in state_queue:
|
||||
if seq_group.request_id == request_id:
|
||||
# We need to reverse the list as we are removing elements
|
||||
# from it as we iterate over it. If we don't do it,
|
||||
# indices will get messed up and we will skip over elements.
|
||||
for seq_group in reversed(state_queue):
|
||||
if seq_group.request_id in request_ids:
|
||||
# Remove the sequence group from the state queue.
|
||||
state_queue.remove(seq_group)
|
||||
for seq in seq_group.seqs:
|
||||
for seq in seq_group.get_seqs():
|
||||
if seq.is_finished():
|
||||
continue
|
||||
self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
|
||||
return
|
||||
seq.status = SequenceStatus.FINISHED_ABORTED
|
||||
self.free_seq(seq)
|
||||
request_ids.remove(seq_group.request_id)
|
||||
if not request_ids:
|
||||
return
|
||||
|
||||
def has_unfinished_seqs(self) -> bool:
|
||||
return self.waiting or self.running or self.swapped
|
||||
@ -115,6 +127,10 @@ class Scheduler:
|
||||
if not self.swapped:
|
||||
ignored_seq_groups: List[SequenceGroup] = []
|
||||
scheduled: List[SequenceGroup] = []
|
||||
# The total number of sequences on the fly, including the
|
||||
# requests in the generation phase.
|
||||
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
|
||||
for seq_group in self.running)
|
||||
num_batched_tokens = 0
|
||||
# Optimization: We do not sort the waiting queue since the preempted
|
||||
# sequence groups are added to the front and the new sequence groups
|
||||
@ -122,19 +138,19 @@ class Scheduler:
|
||||
while self.waiting:
|
||||
seq_group = self.waiting[0]
|
||||
|
||||
assert seq_group.num_seqs() == 1, (
|
||||
"Waiting sequence group should have only one prompt "
|
||||
"sequence.")
|
||||
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
|
||||
prompt_limit = min(
|
||||
self.scheduler_config.max_model_len,
|
||||
self.scheduler_config.max_num_batched_tokens)
|
||||
if num_prompt_tokens > prompt_limit:
|
||||
if num_prompt_tokens > self.prompt_limit:
|
||||
logger.warning(
|
||||
f"Input prompt ({num_prompt_tokens} tokens) is too long"
|
||||
f" and exceeds limit of {prompt_limit}")
|
||||
f" and exceeds limit of {self.prompt_limit}")
|
||||
for seq in seq_group.get_seqs():
|
||||
seq.status = SequenceStatus.FINISHED_IGNORED
|
||||
ignored_seq_groups.append(seq_group)
|
||||
self.waiting.pop(0)
|
||||
break
|
||||
continue
|
||||
|
||||
# If the sequence group cannot be allocated, stop.
|
||||
if not self.block_manager.can_allocate(seq_group):
|
||||
@ -147,11 +163,7 @@ class Scheduler:
|
||||
|
||||
# The total number of sequences in the RUNNING state should not
|
||||
# exceed the maximum number of sequences.
|
||||
num_new_seqs = seq_group.num_seqs(
|
||||
status=SequenceStatus.WAITING)
|
||||
num_curr_seqs = sum(
|
||||
seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
||||
for seq_group in self.running)
|
||||
num_new_seqs = seq_group.get_max_num_running_seqs()
|
||||
if (num_curr_seqs + num_new_seqs >
|
||||
self.scheduler_config.max_num_seqs):
|
||||
break
|
||||
@ -160,6 +172,7 @@ class Scheduler:
|
||||
self._allocate(seq_group)
|
||||
self.running.append(seq_group)
|
||||
num_batched_tokens += num_prompt_tokens
|
||||
num_curr_seqs += num_new_seqs
|
||||
scheduled.append(seq_group)
|
||||
|
||||
if scheduled:
|
||||
@ -205,30 +218,32 @@ class Scheduler:
|
||||
|
||||
# Swap in the sequence groups in the SWAPPED state if possible.
|
||||
self.swapped = self.policy.sort_by_priority(now, self.swapped)
|
||||
while self.swapped and not blocks_to_swap_out:
|
||||
seq_group = self.swapped[0]
|
||||
# If the sequence group has been preempted in this step, stop.
|
||||
if seq_group in preempted:
|
||||
break
|
||||
# If the sequence group cannot be swapped in, stop.
|
||||
if not self.block_manager.can_swap_in(seq_group):
|
||||
break
|
||||
if not preempted:
|
||||
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
|
||||
for seq_group in self.running)
|
||||
|
||||
# The total number of sequences in the RUNNING state should not
|
||||
# exceed the maximum number of sequences.
|
||||
num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
|
||||
num_curr_seqs = sum(
|
||||
seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
||||
for seq_group in self.running)
|
||||
if (num_curr_seqs + num_new_seqs >
|
||||
self.scheduler_config.max_num_seqs):
|
||||
break
|
||||
while self.swapped:
|
||||
seq_group = self.swapped[0]
|
||||
# If the sequence group cannot be swapped in, stop.
|
||||
if not self.block_manager.can_swap_in(seq_group):
|
||||
break
|
||||
|
||||
seq_group = self.swapped.pop(0)
|
||||
self._swap_in(seq_group, blocks_to_swap_in)
|
||||
self._append_slot(seq_group, blocks_to_copy)
|
||||
self.running.append(seq_group)
|
||||
# The total number of sequences in the RUNNING state should not
|
||||
# exceed the maximum number of sequences.
|
||||
num_new_seqs = seq_group.get_max_num_running_seqs()
|
||||
if (num_curr_seqs + num_new_seqs >
|
||||
self.scheduler_config.max_num_seqs):
|
||||
break
|
||||
|
||||
seq_group = self.swapped.pop(0)
|
||||
self._swap_in(seq_group, blocks_to_swap_in)
|
||||
self._append_slot(seq_group, blocks_to_copy)
|
||||
num_curr_seqs += num_new_seqs
|
||||
self.running.append(seq_group)
|
||||
|
||||
# Each sequence in the generation phase only takes one token slot.
|
||||
# Therefore, the number of batched tokens is equal to the number of
|
||||
# sequences in the RUNNING state.
|
||||
num_batched_tokens = sum(
|
||||
seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
||||
for seq_group in self.running)
|
||||
@ -270,40 +285,10 @@ class Scheduler:
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
return seq_group_metadata_list, scheduler_outputs
|
||||
|
||||
def update(
|
||||
self,
|
||||
seq_outputs: Dict[int, SequenceOutputs],
|
||||
) -> List[SequenceGroup]:
|
||||
scheduled: List[SequenceGroup] = []
|
||||
for seq_group in self.running:
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
if seq.seq_id in seq_outputs:
|
||||
scheduled.append(seq_group)
|
||||
break
|
||||
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
||||
self.block_manager.fork(parent_seq, child_seq)
|
||||
|
||||
# Update the scheduled sequences and free blocks.
|
||||
for seq_group in scheduled:
|
||||
# Process beam search results before processing the new tokens.
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
output = seq_outputs[seq.seq_id]
|
||||
if seq.seq_id != output.parent_seq_id:
|
||||
# The sequence is a fork of the parent sequence (beam
|
||||
# search). Free the current sequence.
|
||||
self.block_manager.free(seq)
|
||||
# Fork the parent sequence.
|
||||
parent_seq = seq_group.find(output.parent_seq_id)
|
||||
parent_seq.fork(seq)
|
||||
self.block_manager.fork(parent_seq, seq)
|
||||
|
||||
# Process the new tokens.
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
# Append a new token to the sequence.
|
||||
output = seq_outputs[seq.seq_id]
|
||||
seq.append_token_id(output.output_token, output.logprobs)
|
||||
return scheduled
|
||||
|
||||
def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
|
||||
seq.status = finish_status
|
||||
def free_seq(self, seq: Sequence) -> None:
|
||||
self.block_manager.free(seq)
|
||||
|
||||
def free_finished_seq_groups(self) -> None:
|
||||
@ -340,8 +325,8 @@ class Scheduler:
|
||||
# If preemption mode is not specified, we determine the mode as follows:
|
||||
# We use recomputation by default since it incurs lower overhead than
|
||||
# swapping. However, when the sequence group has multiple sequences
|
||||
# (e.g., beam search), recomputation is not supported. In such a case,
|
||||
# we use swapping instead.
|
||||
# (e.g., beam search), recomputation is not currently supported. In
|
||||
# such a case, we use swapping instead.
|
||||
# FIXME(woosuk): This makes our scheduling policy a bit bizarre.
|
||||
# As swapped sequences are prioritized over waiting sequences,
|
||||
# sequence groups with multiple sequences are implicitly prioritized
|
||||
@ -349,8 +334,7 @@ class Scheduler:
|
||||
# TODO(woosuk): Support recomputation for sequence groups with multiple
|
||||
# sequences. This may require a more sophisticated CUDA kernel.
|
||||
if preemption_mode is None:
|
||||
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
||||
if len(seqs) == 1:
|
||||
if seq_group.get_max_num_running_seqs() == 1:
|
||||
preemption_mode = PreemptionMode.RECOMPUTE
|
||||
else:
|
||||
preemption_mode = PreemptionMode.SWAP
|
||||
|
@ -15,8 +15,7 @@ class EngineArgs:
|
||||
tokenizer_mode: str = 'auto'
|
||||
trust_remote_code: bool = False
|
||||
download_dir: Optional[str] = None
|
||||
use_np_weights: bool = False
|
||||
use_dummy_weights: bool = False
|
||||
load_format: str = 'auto'
|
||||
dtype: str = 'auto'
|
||||
seed: int = 0
|
||||
worker_use_ray: bool = False
|
||||
@ -65,14 +64,21 @@ class EngineArgs:
|
||||
help='directory to download and load the weights, '
|
||||
'default to the default cache dir of '
|
||||
'huggingface')
|
||||
parser.add_argument('--use-np-weights',
|
||||
action='store_true',
|
||||
help='save a numpy copy of model weights for '
|
||||
'faster loading. This can increase the disk '
|
||||
'usage by up to 2x.')
|
||||
parser.add_argument('--use-dummy-weights',
|
||||
action='store_true',
|
||||
help='use dummy values for model weights')
|
||||
parser.add_argument(
|
||||
'--load-format',
|
||||
type=str,
|
||||
default=EngineArgs.load_format,
|
||||
choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
|
||||
help='The format of the model weights to load. '
|
||||
'"auto" will try to load the weights in the safetensors format '
|
||||
'and fall back to the pytorch bin format if safetensors format '
|
||||
'is not available. '
|
||||
'"pt" will load the weights in the pytorch bin format. '
|
||||
'"safetensors" will load the weights in the safetensors format. '
|
||||
'"npcache" will load the weights in pytorch format and store '
|
||||
'a numpy cache to speed up the loading. '
|
||||
'"dummy" will initialize the weights with random values, '
|
||||
'which is mainly for profiling.')
|
||||
# TODO(woosuk): Support FP32.
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
@ -146,9 +152,8 @@ class EngineArgs:
|
||||
# Initialize the configs.
|
||||
model_config = ModelConfig(self.model, self.tokenizer,
|
||||
self.tokenizer_mode, self.trust_remote_code,
|
||||
self.download_dir, self.use_np_weights,
|
||||
self.use_dummy_weights, self.dtype,
|
||||
self.seed)
|
||||
self.download_dir, self.load_format,
|
||||
self.dtype, self.seed)
|
||||
cache_config = CacheConfig(self.block_size,
|
||||
self.gpu_memory_utilization,
|
||||
self.swap_space)
|
||||
|
@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
@ -12,7 +13,202 @@ from vllm.sampling_params import SamplingParams
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
|
||||
|
||||
class AsyncEngineDeadError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _raise_exception_on_finish(task: asyncio.Task,
|
||||
request_tracker: "RequestTracker") -> None:
|
||||
msg = ("Task finished unexpectedly. This should never happen! "
|
||||
"Please open an issue on Github.")
|
||||
try:
|
||||
try:
|
||||
task.result()
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as exc:
|
||||
raise AsyncEngineDeadError(
|
||||
msg + " See stack trace above for the actual cause.") from exc
|
||||
raise AsyncEngineDeadError(msg)
|
||||
except Exception as exc:
|
||||
request_tracker.propagate_exception(exc)
|
||||
raise exc
|
||||
|
||||
|
||||
class AsyncStream:
|
||||
"""A stream of RequestOutputs for a request that can be
|
||||
iterated over asynchronously."""
|
||||
|
||||
def __init__(self, request_id: str) -> None:
|
||||
self.request_id = request_id
|
||||
self._queue = asyncio.Queue()
|
||||
self._finished = False
|
||||
|
||||
def put(self, item: RequestOutput) -> None:
|
||||
if self._finished:
|
||||
return
|
||||
self._queue.put_nowait(item)
|
||||
|
||||
def finish(self) -> None:
|
||||
self._queue.put_nowait(StopIteration)
|
||||
self._finished = True
|
||||
|
||||
@property
|
||||
def finished(self) -> bool:
|
||||
return self._finished
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> RequestOutput:
|
||||
result = await self._queue.get()
|
||||
if result is StopIteration:
|
||||
raise StopAsyncIteration
|
||||
elif isinstance(result, Exception):
|
||||
raise result
|
||||
return result
|
||||
|
||||
|
||||
class RequestTracker:
|
||||
"""Synchronous abstraction for tracking requests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._request_streams: Dict[str, AsyncStream] = {}
|
||||
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
|
||||
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
|
||||
dict]] = asyncio.Queue()
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self._request_streams
|
||||
|
||||
def propagate_exception(self, exc: Exception) -> None:
|
||||
"""Propagate an exception to all request streams."""
|
||||
for stream in self._request_streams.values():
|
||||
stream.put(exc)
|
||||
|
||||
def process_request_output(self,
|
||||
request_output: RequestOutput,
|
||||
*,
|
||||
verbose: bool = False) -> None:
|
||||
"""Process a request output from the engine."""
|
||||
request_id = request_output.request_id
|
||||
|
||||
self._request_streams[request_id].put(request_output)
|
||||
if request_output.finished:
|
||||
if verbose:
|
||||
logger.info(f"Finished request {request_id}.")
|
||||
self.abort_request(request_id)
|
||||
|
||||
def add_request(self, request_id: str,
|
||||
**engine_add_request_kwargs) -> AsyncStream:
|
||||
"""Add a request to be sent to the engine on the next background
|
||||
loop iteration."""
|
||||
if request_id in self._request_streams:
|
||||
raise KeyError(f"Request {request_id} already exists.")
|
||||
|
||||
stream = AsyncStream(request_id)
|
||||
self._new_requests.put_nowait((stream, {
|
||||
"request_id": request_id,
|
||||
**engine_add_request_kwargs
|
||||
}))
|
||||
return stream
|
||||
|
||||
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
|
||||
"""Abort a request during next background loop iteration."""
|
||||
if verbose:
|
||||
logger.info(f"Aborted request {request_id}.")
|
||||
|
||||
self._finished_requests.put_nowait(request_id)
|
||||
|
||||
if request_id not in self._request_streams or self._request_streams[
|
||||
request_id].finished:
|
||||
# The request has already finished or been aborted.
|
||||
return
|
||||
|
||||
self._request_streams[request_id].finish()
|
||||
|
||||
def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
|
||||
"""Get the new requests and finished requests to be
|
||||
sent to the engine."""
|
||||
new_requests: List[dict] = []
|
||||
finished_requests: Set[str] = set()
|
||||
|
||||
while not self._finished_requests.empty():
|
||||
request_id = self._finished_requests.get_nowait()
|
||||
finished_requests.add(request_id)
|
||||
self._request_streams.pop(request_id, None)
|
||||
|
||||
while not self._new_requests.empty():
|
||||
stream, new_request = self._new_requests.get_nowait()
|
||||
if stream.request_id in finished_requests:
|
||||
# The request has already been aborted.
|
||||
stream.finish()
|
||||
continue
|
||||
self._request_streams[stream.request_id] = stream
|
||||
new_requests.append(new_request)
|
||||
|
||||
return new_requests, finished_requests
|
||||
|
||||
|
||||
class _AsyncLLMEngine(LLMEngine):
|
||||
"""Extension of LLMEngine to add async methods."""
|
||||
|
||||
async def step_async(self) -> List[RequestOutput]:
|
||||
"""Performs one decoding iteration and returns newly generated results.
|
||||
The workers are ran asynchronously if possible.
|
||||
|
||||
This function performs one decoding iteration of the engine. It first
|
||||
schedules the sequences to be executed in the next iteration and the
|
||||
token blocks to be swapped in/out/copy. Then, it executes the model
|
||||
and updates the scheduler with the model outputs. Finally, it decodes
|
||||
the sequences and returns the newly generated results.
|
||||
"""
|
||||
(seq_group_metadata_list, scheduler_outputs,
|
||||
early_return) = self._schedule()
|
||||
if early_return is not None:
|
||||
return early_return
|
||||
|
||||
# Execute the model.
|
||||
output = await self._run_workers_async(
|
||||
"execute_model",
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
|
||||
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||
)
|
||||
|
||||
return self._process_model_outputs(output, scheduler_outputs)
|
||||
|
||||
async def _run_workers_async(
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
get_all_outputs: bool = False,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers."""
|
||||
all_outputs = []
|
||||
for worker in self.workers:
|
||||
if self.parallel_config.worker_use_ray:
|
||||
executor = partial(worker.execute_method.remote, method)
|
||||
else:
|
||||
executor = getattr(worker, method)
|
||||
|
||||
output = executor(*args, **kwargs)
|
||||
all_outputs.append(output)
|
||||
|
||||
if self.parallel_config.worker_use_ray:
|
||||
all_outputs = await asyncio.gather(*all_outputs)
|
||||
|
||||
if get_all_outputs:
|
||||
return all_outputs
|
||||
|
||||
# Make sure all workers have the same results.
|
||||
output = all_outputs[0]
|
||||
for other_output in all_outputs[1:]:
|
||||
assert output == other_output
|
||||
return output
|
||||
|
||||
|
||||
class AsyncLLMEngine:
|
||||
@ -37,49 +233,117 @@ class AsyncLLMEngine:
|
||||
*args, *kwargs: Arguments for LLMEngine.
|
||||
"""
|
||||
|
||||
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
|
||||
|
||||
def __init__(self,
|
||||
worker_use_ray: bool,
|
||||
engine_use_ray: bool,
|
||||
*args,
|
||||
log_requests: bool = True,
|
||||
start_engine_loop: bool = False,
|
||||
**kwargs) -> None:
|
||||
self.worker_use_ray = worker_use_ray
|
||||
self.engine_use_ray = engine_use_ray
|
||||
self.log_requests = log_requests
|
||||
if not self.engine_use_ray:
|
||||
engine_class = LLMEngine
|
||||
elif self.worker_use_ray:
|
||||
engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
|
||||
else:
|
||||
engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
|
||||
self.engine = engine_class(*args, **kwargs)
|
||||
# Request id -> request output.
|
||||
self.request_outputs: Dict[str, RequestOutput] = {}
|
||||
# Request id -> event to notify that there is new output.
|
||||
self.request_events: Dict[str, asyncio.Event] = {}
|
||||
self.is_engine_running = False
|
||||
self.kicking_request_id: Optional[str] = None
|
||||
self.engine = self._init_engine(*args, **kwargs)
|
||||
|
||||
async def engine_step(self, kicking_request_id: Optional[str] = None):
|
||||
self.request_tracker: RequestTracker = RequestTracker()
|
||||
self.background_loop = None
|
||||
if start_engine_loop:
|
||||
self.start_background_loop()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return (self.background_loop is not None
|
||||
and not self.background_loop.done())
|
||||
|
||||
def start_background_loop(self) -> None:
|
||||
"""Start the background loop."""
|
||||
if self.is_running:
|
||||
raise RuntimeError("Background loop is already running.")
|
||||
self.background_loop = asyncio.get_event_loop().create_task(
|
||||
self.run_engine_loop())
|
||||
self.background_loop.add_done_callback(
|
||||
partial(_raise_exception_on_finish,
|
||||
request_tracker=self.request_tracker))
|
||||
|
||||
def _init_engine(self, *args,
|
||||
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
|
||||
if not self.engine_use_ray:
|
||||
engine_class = self._engine_class
|
||||
elif self.worker_use_ray:
|
||||
engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
|
||||
else:
|
||||
engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
|
||||
return engine_class(*args, **kwargs)
|
||||
|
||||
async def engine_step(self):
|
||||
"""Kick the engine to process the waiting requests."""
|
||||
self.is_engine_running = True
|
||||
self.kicking_request_id = kicking_request_id
|
||||
|
||||
new_requests, finished_requests = (
|
||||
self.request_tracker.get_new_and_finished_requests())
|
||||
|
||||
for new_request in new_requests:
|
||||
# Add the request into the vLLM engine's waiting queue.
|
||||
# TODO: Maybe add add_request_batch to reduce Ray overhead
|
||||
if self.engine_use_ray:
|
||||
await self.engine.add_request.remote(**new_request)
|
||||
else:
|
||||
self.engine.add_request(**new_request)
|
||||
|
||||
if finished_requests:
|
||||
await self._engine_abort(finished_requests)
|
||||
|
||||
if self.engine_use_ray:
|
||||
request_outputs = await self.engine.step.remote()
|
||||
else:
|
||||
# Yield to the event loop to allow other coroutines to run
|
||||
# while is_engine_running is True. This let the engine to add new
|
||||
# requests into the queue.
|
||||
await asyncio.sleep(0)
|
||||
request_outputs = self.engine.step()
|
||||
self.is_engine_running = False
|
||||
self.kicking_request_id = None
|
||||
request_outputs = await self.engine.step_async()
|
||||
|
||||
# Notify the waiting coroutines that there are new outputs ready.
|
||||
# Put the outputs into the corresponding streams.
|
||||
for request_output in request_outputs:
|
||||
request_id = request_output.request_id
|
||||
self.request_outputs[request_id] = request_output
|
||||
self.request_events[request_id].set()
|
||||
self.request_tracker.process_request_output(
|
||||
request_output, verbose=self.log_requests)
|
||||
|
||||
async def _engine_abort(self, request_ids: Iterable[str]):
|
||||
if self.engine_use_ray:
|
||||
await self.engine.abort_request.remote(request_ids)
|
||||
else:
|
||||
self.engine.abort_request(request_ids)
|
||||
|
||||
async def run_engine_loop(self):
|
||||
while True:
|
||||
await self.engine_step()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
arrival_time: Optional[float] = None,
|
||||
) -> AsyncStream:
|
||||
if self.log_requests:
|
||||
logger.info(f"Received request {request_id}: "
|
||||
f"prompt: {prompt!r}, "
|
||||
f"sampling params: {sampling_params}, "
|
||||
f"prompt token ids: {prompt_token_ids}.")
|
||||
|
||||
if not self.is_running:
|
||||
raise AsyncEngineDeadError(
|
||||
"Background loop is not running. If it was running, "
|
||||
"inspect the output to find the stacktrace of the "
|
||||
"error that caused the background loop to stop "
|
||||
"(AsyncEngineDeadError).")
|
||||
|
||||
stream = self.request_tracker.add_request(
|
||||
request_id,
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
|
||||
return stream
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
@ -108,76 +372,19 @@ class AsyncLLMEngine:
|
||||
# Preprocess the request.
|
||||
arrival_time = time.time()
|
||||
|
||||
# Create an event to notify us that there is new output from the
|
||||
# vLLM engine.
|
||||
request_event = asyncio.Event()
|
||||
self.request_events[request_id] = request_event
|
||||
try:
|
||||
stream = await self.add_request(request_id,
|
||||
prompt,
|
||||
sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
|
||||
if self.log_requests:
|
||||
logger.info(f"Received request {request_id}: "
|
||||
f"prompt: {prompt!r}, "
|
||||
f"sampling params: {sampling_params}, "
|
||||
f"prompt token ids: {prompt_token_ids}.")
|
||||
|
||||
# Add the request into the vLLM engine's waiting queue.
|
||||
if self.engine_use_ray:
|
||||
await self.engine.add_request.remote(
|
||||
request_id,
|
||||
prompt,
|
||||
sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
else:
|
||||
self.engine.add_request(request_id,
|
||||
prompt,
|
||||
sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
|
||||
# The vLLM engine does not have a background loop that keeps
|
||||
# processing incoming requests. Therefore, we need to keep kicking
|
||||
# the engine to process the requests.
|
||||
while True:
|
||||
if request_id not in self.request_events:
|
||||
# The request has been aborted.
|
||||
return
|
||||
|
||||
# Kick the engine if the engine is not running.
|
||||
if not self.is_engine_running:
|
||||
try:
|
||||
await self.engine_step(request_id)
|
||||
except RuntimeError as e:
|
||||
await self.abort(request_id)
|
||||
raise e
|
||||
|
||||
# Wait for new output. The group_event will be set in engine_step
|
||||
# when there is new output available for the sequence group.
|
||||
# Added a timeout to prevent deadlock.
|
||||
try:
|
||||
await asyncio.wait_for(request_event.wait(),
|
||||
timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
# Reset the event to wait for the next output.
|
||||
request_event.clear()
|
||||
|
||||
# Decode and return new outputs.
|
||||
request_output = self.request_outputs[request_id]
|
||||
yield request_output
|
||||
|
||||
# Once finished, release the resources of the sequence group.
|
||||
if request_output.finished:
|
||||
if self.log_requests:
|
||||
logger.info(f"Finished request {request_id}.")
|
||||
|
||||
del self.request_outputs[request_id]
|
||||
del self.request_events[request_id]
|
||||
# Kick the engine if the engine is not running. This is to
|
||||
# prevent that there are still requests in engine's waiting
|
||||
# queue to be executed.
|
||||
if not self.is_engine_running:
|
||||
await self.engine_step()
|
||||
break
|
||||
async for request_output in stream:
|
||||
yield request_output
|
||||
except Exception as e:
|
||||
# If there is an exception, abort the request.
|
||||
self._abort(request_id)
|
||||
raise e
|
||||
|
||||
async def abort(self, request_id: str) -> None:
|
||||
"""Abort a request.
|
||||
@ -188,28 +395,26 @@ class AsyncLLMEngine:
|
||||
Args:
|
||||
request_id: The unique id of the request.
|
||||
"""
|
||||
if request_id not in self.request_events:
|
||||
# The request has already finished or been aborted.
|
||||
return
|
||||
if not self.is_running:
|
||||
raise AsyncEngineDeadError(
|
||||
"Background loop is not running. If it was running, "
|
||||
"inspect the output to find the stacktrace of the "
|
||||
"error that caused the background loop to stop "
|
||||
"(AsyncEngineDeadError).")
|
||||
|
||||
if self.log_requests:
|
||||
logger.info(f"Aborted request {request_id}.")
|
||||
return self._abort(request_id)
|
||||
|
||||
if self.engine_use_ray:
|
||||
await self.engine.abort_request.remote(request_id)
|
||||
else:
|
||||
self.engine.abort_request(request_id)
|
||||
def _abort(self, request_id: str) -> None:
|
||||
"""Abort a request.
|
||||
|
||||
if request_id in self.request_events:
|
||||
del self.request_events[request_id]
|
||||
if request_id in self.request_outputs:
|
||||
del self.request_outputs[request_id]
|
||||
Abort a submitted request. If the request is finished or not found,
|
||||
this method will be a no-op.
|
||||
|
||||
# To prevent deadlock when a request is aborted while the engine is
|
||||
# running.
|
||||
if self.kicking_request_id == request_id:
|
||||
self.is_engine_running = False
|
||||
self.kicking_request_id = None
|
||||
Args:
|
||||
request_id: The unique id of the request.
|
||||
"""
|
||||
self.request_tracker.abort_request(request_id,
|
||||
verbose=self.log_requests)
|
||||
|
||||
async def get_model_config(self) -> ModelConfig:
|
||||
"""Get the model configuration of the vLLM engine."""
|
||||
@ -220,7 +425,8 @@ class AsyncLLMEngine:
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(cls,
|
||||
engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
|
||||
engine_args: AsyncEngineArgs,
|
||||
start_engine_loop: bool = False) -> "AsyncLLMEngine":
|
||||
"""Creates an async LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
engine_configs = engine_args.create_engine_configs()
|
||||
@ -235,5 +441,6 @@ class AsyncLLMEngine:
|
||||
distributed_init_method,
|
||||
placement_group,
|
||||
log_requests=not engine_args.disable_log_requests,
|
||||
log_stats=not engine_args.disable_log_stats)
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
start_engine_loop=start_engine_loop)
|
||||
return engine
|
||||
|
@ -1,17 +1,19 @@
|
||||
import time
|
||||
import copy
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
from vllm.core.scheduler import Scheduler
|
||||
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker
|
||||
from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
|
||||
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceOutputs,
|
||||
SequenceStatus)
|
||||
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
||||
get_tokenizer)
|
||||
from vllm.utils import Counter
|
||||
@ -74,9 +76,8 @@ class LLMEngine:
|
||||
f"tokenizer_mode={model_config.tokenizer_mode}, "
|
||||
f"trust_remote_code={model_config.trust_remote_code}, "
|
||||
f"dtype={model_config.dtype}, "
|
||||
f"use_dummy_weights={model_config.use_dummy_weights}, "
|
||||
f"download_dir={model_config.download_dir!r}, "
|
||||
f"use_np_weights={model_config.use_np_weights}, "
|
||||
f"load_format={model_config.load_format}, "
|
||||
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
||||
f"seed={model_config.seed})")
|
||||
# TODO(woosuk): Print more configs in debug mode.
|
||||
@ -135,7 +136,8 @@ class LLMEngine:
|
||||
get_all_outputs=True,
|
||||
)
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup"):
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
|
||||
@ -150,6 +152,7 @@ class LLMEngine:
|
||||
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||
placement_group=placement_group,
|
||||
placement_group_capture_child_tasks=True),
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorker).remote()
|
||||
self.workers.append(worker)
|
||||
|
||||
@ -255,24 +258,21 @@ class LLMEngine:
|
||||
|
||||
# Create the sequences.
|
||||
block_size = self.cache_config.block_size
|
||||
seqs: List[Sequence] = []
|
||||
for _ in range(sampling_params.best_of):
|
||||
seq_id = next(self.seq_counter)
|
||||
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
|
||||
seqs.append(seq)
|
||||
seq_id = next(self.seq_counter)
|
||||
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
|
||||
|
||||
# Create the sequence group.
|
||||
seq_group = SequenceGroup(request_id, seqs, sampling_params,
|
||||
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
||||
arrival_time)
|
||||
|
||||
# Add the sequence group to the scheduler.
|
||||
self.scheduler.add_seq_group(seq_group)
|
||||
|
||||
def abort_request(self, request_id: str) -> None:
|
||||
"""Aborts a request with the given ID.
|
||||
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
||||
"""Aborts a request(s) with the given ID.
|
||||
|
||||
Args:
|
||||
request_id: The ID of the request to abort.
|
||||
request_id: The ID(s) of the request to abort.
|
||||
"""
|
||||
self.scheduler.abort_seq_group(request_id)
|
||||
|
||||
@ -288,6 +288,251 @@ class LLMEngine:
|
||||
"""Returns True if there are unfinished requests."""
|
||||
return self.scheduler.has_unfinished_seqs()
|
||||
|
||||
def _schedule(
|
||||
self
|
||||
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
|
||||
Optional[List[RequestOutput]]]:
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||
if scheduler_outputs.is_empty():
|
||||
return seq_group_metadata_list, scheduler_outputs, [
|
||||
RequestOutput.from_seq_group(seq_group)
|
||||
for seq_group in scheduler_outputs.ignored_seq_groups
|
||||
]
|
||||
return seq_group_metadata_list, scheduler_outputs, None
|
||||
|
||||
def _check_beam_search_early_stopping(
|
||||
self,
|
||||
early_stopping: Union[bool, str],
|
||||
sampling_params: SamplingParams,
|
||||
best_running_seq: Sequence,
|
||||
current_worst_seq: Sequence,
|
||||
) -> bool:
|
||||
assert sampling_params.use_beam_search
|
||||
length_penalty = sampling_params.length_penalty
|
||||
if early_stopping is True:
|
||||
return True
|
||||
|
||||
current_worst_score = (current_worst_seq.get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id))
|
||||
if early_stopping is False:
|
||||
highest_attainable_score = (best_running_seq.get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id))
|
||||
else:
|
||||
assert early_stopping == "never"
|
||||
if length_penalty > 0.0:
|
||||
# If length_penalty > 0.0, beam search will prefer longer
|
||||
# sequences. The highest attainable score calculation is
|
||||
# based on the longest possible sequence length in this case.
|
||||
max_possible_length = max(
|
||||
best_running_seq.get_prompt_len() +
|
||||
sampling_params.max_tokens,
|
||||
self.scheduler_config.max_model_len)
|
||||
highest_attainable_score = (
|
||||
best_running_seq.get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
seq_len=max_possible_length))
|
||||
else:
|
||||
# Otherwise, beam search will prefer shorter sequences. The
|
||||
# highest attainable score calculation is based on the current
|
||||
# sequence length.
|
||||
highest_attainable_score = (
|
||||
best_running_seq.get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id))
|
||||
return current_worst_score >= highest_attainable_score
|
||||
|
||||
def _process_sequence_group_samples(
|
||||
self, seq_group: SequenceGroup,
|
||||
samples: List[SequenceOutputs]) -> None:
|
||||
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
|
||||
existing_finished_seqs = seq_group.get_finished_seqs()
|
||||
parent_child_dict = {
|
||||
parent_seq.seq_id: []
|
||||
for parent_seq in parent_seqs
|
||||
}
|
||||
for sample in samples:
|
||||
parent_child_dict[sample.parent_seq_id].append(sample)
|
||||
# List of (child, parent)
|
||||
child_seqs: List[Tuple[Sequence, Sequence]] = []
|
||||
|
||||
# Process the child samples for each parent sequence
|
||||
for parent in parent_seqs:
|
||||
child_samples: List[SequenceOutputs] = parent_child_dict[
|
||||
parent.seq_id]
|
||||
if len(child_samples) == 0:
|
||||
# This parent sequence has no children samples. Remove
|
||||
# the parent sequence from the sequence group since it will
|
||||
# not be used in the future iterations.
|
||||
parent.status = SequenceStatus.FINISHED_ABORTED
|
||||
seq_group.remove(parent.seq_id)
|
||||
self.scheduler.free_seq(parent)
|
||||
continue
|
||||
# Fork the parent sequence if there are multiple child samples.
|
||||
for child_sample in child_samples[:-1]:
|
||||
new_child_seq_id = next(self.seq_counter)
|
||||
child = parent.fork(new_child_seq_id)
|
||||
child.append_token_id(child_sample.output_token,
|
||||
child_sample.logprobs)
|
||||
child_seqs.append((child, parent))
|
||||
# Continue the parent sequence for the last child sample.
|
||||
# We reuse the parent sequence here to reduce redundant memory
|
||||
# copies, especially when using non-beam search sampling methods.
|
||||
last_child_sample = child_samples[-1]
|
||||
parent.append_token_id(last_child_sample.output_token,
|
||||
last_child_sample.logprobs)
|
||||
child_seqs.append((parent, parent))
|
||||
|
||||
for seq, _ in child_seqs:
|
||||
self._decode_sequence(seq)
|
||||
self._check_stop(seq, seq_group.sampling_params)
|
||||
|
||||
# Non-beam search case
|
||||
if not seq_group.sampling_params.use_beam_search:
|
||||
# For newly created child sequences, add them to the sequence group
|
||||
# and fork them in block manager if they are not finished.
|
||||
for seq, parent in child_seqs:
|
||||
if seq is not parent:
|
||||
seq_group.add(seq)
|
||||
if not seq.is_finished():
|
||||
self.scheduler.fork_seq(parent, seq)
|
||||
|
||||
# Free the finished and selected parent sequences' memory in block
|
||||
# manager. Keep them in the sequence group as candidate output.
|
||||
# NOTE: we need to fork the new sequences before freeing the
|
||||
# old sequences.
|
||||
for seq, parent in child_seqs:
|
||||
if seq is parent and seq.is_finished():
|
||||
self.scheduler.free_seq(seq)
|
||||
return
|
||||
|
||||
# Beam search case
|
||||
# Select the child sequences to keep in the sequence group.
|
||||
selected_child_seqs = []
|
||||
unselected_child_seqs = []
|
||||
beam_width = seq_group.sampling_params.best_of
|
||||
length_penalty = seq_group.sampling_params.length_penalty
|
||||
|
||||
# Select the newly finished sequences with the highest scores
|
||||
# to replace existing finished sequences.
|
||||
# Tuple of (seq, parent, is_new)
|
||||
existing_finished_seqs = [(seq, None, False)
|
||||
for seq in existing_finished_seqs]
|
||||
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
|
||||
if seq.is_finished()]
|
||||
all_finished_seqs = existing_finished_seqs + new_finished_seqs
|
||||
# Sort the finished sequences by their scores.
|
||||
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id),
|
||||
reverse=True)
|
||||
for seq, parent, is_new in all_finished_seqs[:beam_width]:
|
||||
if is_new:
|
||||
# A newly generated child sequence finishes and has a high
|
||||
# score, so we will add it into the sequence group.
|
||||
selected_child_seqs.append((seq, parent))
|
||||
for seq, parent, is_new in all_finished_seqs[beam_width:]:
|
||||
if is_new:
|
||||
# A newly generated child sequence finishes but has a low
|
||||
# score, so we will not add it into the sequence group.
|
||||
# Additionally, if this sequence is a continuation of a
|
||||
# parent sequence, we will need remove the parent sequence
|
||||
# from the sequence group.
|
||||
unselected_child_seqs.append((seq, parent))
|
||||
else:
|
||||
# An existing finished sequence has a low score, so we will
|
||||
# remove it from the sequence group.
|
||||
seq_group.remove(seq.seq_id)
|
||||
|
||||
# select the top beam_width sequences from the running
|
||||
# sequences for the next iteration to continue the beam
|
||||
# search.
|
||||
running_child_seqs = [(seq, parent) for seq, parent in child_seqs
|
||||
if not seq.is_finished()]
|
||||
# Sort the running sequences by their scores.
|
||||
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
|
||||
length_penalty=length_penalty,
|
||||
eos_token_id=self.tokenizer.eos_token_id),
|
||||
reverse=True)
|
||||
|
||||
# Check if we can stop the beam search.
|
||||
if len(running_child_seqs) == 0:
|
||||
# No running sequences, stop the beam search.
|
||||
stop_beam_search = True
|
||||
elif len(all_finished_seqs) < beam_width:
|
||||
# Not enough finished sequences, continue the beam search.
|
||||
stop_beam_search = False
|
||||
else:
|
||||
# Check the early stopping criteria
|
||||
best_running_seq = running_child_seqs[0][0]
|
||||
current_worst_seq = all_finished_seqs[beam_width - 1][0]
|
||||
stop_beam_search = self._check_beam_search_early_stopping(
|
||||
seq_group.sampling_params.early_stopping,
|
||||
seq_group.sampling_params, best_running_seq, current_worst_seq)
|
||||
|
||||
if stop_beam_search:
|
||||
# Stop the beam search and remove all the running sequences from
|
||||
# the sequence group.
|
||||
unselected_child_seqs.extend(running_child_seqs)
|
||||
else:
|
||||
# Continue the beam search and select the top beam_width sequences
|
||||
# to continue the beam search.
|
||||
selected_child_seqs.extend(running_child_seqs[:beam_width])
|
||||
# The remaining running sequences will not be used in the next
|
||||
# iteration. Again, if these sequences are continuations of
|
||||
# parent sequences, we will need to remove the parent sequences
|
||||
# from the sequence group.
|
||||
unselected_child_seqs.extend(running_child_seqs[beam_width:])
|
||||
|
||||
# For newly created child sequences, add them to the sequence group
|
||||
# and fork them in block manager if they are not finished.
|
||||
for seq, parent in selected_child_seqs:
|
||||
if seq is not parent:
|
||||
seq_group.add(seq)
|
||||
if not seq.is_finished():
|
||||
self.scheduler.fork_seq(parent, seq)
|
||||
|
||||
# Free the finished and selected parent sequences' memory in block
|
||||
# manager. Keep them in the sequence group as candidate output.
|
||||
for seq, parent in selected_child_seqs:
|
||||
if seq is parent and seq.is_finished():
|
||||
self.scheduler.free_seq(seq)
|
||||
|
||||
# Remove the unselected parent sequences from the sequence group and
|
||||
# free their memory in block manager.
|
||||
for seq, parent in unselected_child_seqs:
|
||||
if seq is parent:
|
||||
# Remove the parent sequence if it is not selected for next
|
||||
# iteration
|
||||
seq_group.remove(seq.seq_id)
|
||||
self.scheduler.free_seq(seq)
|
||||
|
||||
def _process_model_outputs(
|
||||
self, output: SamplerOutput,
|
||||
scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
|
||||
# Update the scheduled sequence groups with the model outputs.
|
||||
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
|
||||
for seq_group, samples in zip(scheduled_seq_groups, output):
|
||||
self._process_sequence_group_samples(seq_group, samples)
|
||||
|
||||
# Free the finished sequence groups.
|
||||
self.scheduler.free_finished_seq_groups()
|
||||
|
||||
# Create the outputs.
|
||||
request_outputs: List[RequestOutput] = []
|
||||
for seq_group in (scheduled_seq_groups +
|
||||
scheduler_outputs.ignored_seq_groups):
|
||||
request_output = RequestOutput.from_seq_group(seq_group)
|
||||
request_outputs.append(request_output)
|
||||
|
||||
if self.log_stats:
|
||||
# Log the system stats.
|
||||
self._log_system_stats(scheduler_outputs.prompt_run,
|
||||
scheduler_outputs.num_batched_tokens)
|
||||
return request_outputs
|
||||
|
||||
def step(self) -> List[RequestOutput]:
|
||||
"""Performs one decoding iteration and returns newly generated results.
|
||||
|
||||
@ -297,17 +542,10 @@ class LLMEngine:
|
||||
and updates the scheduler with the model outputs. Finally, it decodes
|
||||
the sequences and returns the newly generated results.
|
||||
"""
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||
if scheduler_outputs.is_empty():
|
||||
if not scheduler_outputs.ignored_seq_groups:
|
||||
# Nothing to do.
|
||||
return []
|
||||
# If there are ignored seq groups, we need to return them as the
|
||||
# request outputs.
|
||||
return [
|
||||
RequestOutput.from_seq_group(seq_group)
|
||||
for seq_group in scheduler_outputs.ignored_seq_groups
|
||||
]
|
||||
(seq_group_metadata_list, scheduler_outputs,
|
||||
early_return) = self._schedule()
|
||||
if early_return is not None:
|
||||
return early_return
|
||||
|
||||
# Execute the model.
|
||||
output = self._run_workers(
|
||||
@ -317,27 +555,8 @@ class LLMEngine:
|
||||
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||
)
|
||||
# Update the scheduler with the model outputs.
|
||||
seq_groups = self.scheduler.update(output)
|
||||
|
||||
# Decode the sequences.
|
||||
self._decode_sequences(seq_groups)
|
||||
# Stop the sequences that meet the stopping criteria.
|
||||
self._stop_sequences(seq_groups)
|
||||
# Free the finished sequence groups.
|
||||
self.scheduler.free_finished_seq_groups()
|
||||
|
||||
# Create the outputs.
|
||||
request_outputs: List[RequestOutput] = []
|
||||
for seq_group in seq_groups + scheduler_outputs.ignored_seq_groups:
|
||||
request_output = RequestOutput.from_seq_group(seq_group)
|
||||
request_outputs.append(request_output)
|
||||
|
||||
if self.log_stats:
|
||||
# Log the system stats.
|
||||
self._log_system_stats(scheduler_outputs.prompt_run,
|
||||
scheduler_outputs.num_batched_tokens)
|
||||
return request_outputs
|
||||
return self._process_model_outputs(output, scheduler_outputs)
|
||||
|
||||
def _log_system_stats(
|
||||
self,
|
||||
@ -402,55 +621,44 @@ class LLMEngine:
|
||||
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
|
||||
self.last_logging_time = now
|
||||
|
||||
def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
|
||||
"""Decodes the sequence outputs."""
|
||||
for seq_group in seq_groups:
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
new_token, new_output_text = detokenize_incrementally(
|
||||
self.tokenizer,
|
||||
seq.output_tokens,
|
||||
seq.get_last_token_id(),
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
if new_token is not None:
|
||||
seq.output_tokens.append(new_token)
|
||||
seq.output_text = new_output_text
|
||||
def _decode_sequence(self, seq: Sequence) -> None:
|
||||
"""Decodes the new token for a sequence."""
|
||||
new_token, new_output_text = detokenize_incrementally(
|
||||
self.tokenizer,
|
||||
seq.output_tokens,
|
||||
seq.get_last_token_id(),
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
if new_token is not None:
|
||||
seq.output_tokens.append(new_token)
|
||||
seq.output_text = new_output_text
|
||||
|
||||
def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
|
||||
def _check_stop(self, seq: Sequence,
|
||||
sampling_params: SamplingParams) -> None:
|
||||
"""Stop the finished sequences."""
|
||||
for seq_group in seq_groups:
|
||||
sampling_params = seq_group.sampling_params
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
# Check if the sequence has generated a stop string.
|
||||
stopped = False
|
||||
for stop_str in sampling_params.stop:
|
||||
if seq.output_text.endswith(stop_str):
|
||||
# Truncate the output text so that the stop string is
|
||||
# not included in the output.
|
||||
seq.output_text = seq.output_text[:-len(stop_str)]
|
||||
self.scheduler.free_seq(
|
||||
seq, SequenceStatus.FINISHED_STOPPED)
|
||||
stopped = True
|
||||
break
|
||||
if stopped:
|
||||
continue
|
||||
for stop_str in sampling_params.stop:
|
||||
if seq.output_text.endswith(stop_str):
|
||||
# Truncate the output text so that the stop string is
|
||||
# not included in the output.
|
||||
seq.output_text = seq.output_text[:-len(stop_str)]
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
|
||||
# Check if the sequence has reached max_model_len.
|
||||
if seq.get_len() > self.scheduler_config.max_model_len:
|
||||
self.scheduler.free_seq(
|
||||
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
|
||||
continue
|
||||
# Check if the sequence has reached max_tokens.
|
||||
if seq.get_output_len() == sampling_params.max_tokens:
|
||||
self.scheduler.free_seq(
|
||||
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
|
||||
continue
|
||||
# Check if the sequence has generated the EOS token.
|
||||
if not sampling_params.ignore_eos:
|
||||
if seq.get_last_token_id() == self.tokenizer.eos_token_id:
|
||||
self.scheduler.free_seq(
|
||||
seq, SequenceStatus.FINISHED_STOPPED)
|
||||
continue
|
||||
# Check if the sequence has reached max_model_len.
|
||||
if seq.get_len() > self.scheduler_config.max_model_len:
|
||||
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
return
|
||||
|
||||
# Check if the sequence has reached max_tokens.
|
||||
if seq.get_output_len() == sampling_params.max_tokens:
|
||||
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
return
|
||||
|
||||
# Check if the sequence has generated the EOS token.
|
||||
if ((not sampling_params.ignore_eos)
|
||||
and seq.get_last_token_id() == self.tokenizer.eos_token_id):
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
|
||||
def _run_workers(
|
||||
self,
|
||||
|
@ -14,6 +14,7 @@ from vllm.utils import random_uuid
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds.
|
||||
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds.
|
||||
app = FastAPI()
|
||||
engine = None
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
@ -30,6 +31,10 @@ async def generate(request: Request) -> Response:
|
||||
stream = request_dict.pop("stream", False)
|
||||
sampling_params = SamplingParams(**request_dict)
|
||||
request_id = random_uuid()
|
||||
|
||||
if not engine.is_running:
|
||||
engine.start_background_loop()
|
||||
|
||||
results_generator = engine.generate(prompt, sampling_params, request_id)
|
||||
|
||||
# Streaming case
|
||||
@ -75,7 +80,8 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args,
|
||||
start_engine_loop=False)
|
||||
|
||||
uvicorn.run(app,
|
||||
host=args.host,
|
||||
|
@ -44,6 +44,7 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds
|
||||
logger = init_logger(__name__)
|
||||
served_model = None
|
||||
app = fastapi.FastAPI()
|
||||
engine = None
|
||||
|
||||
|
||||
def create_error_response(status_code: HTTPStatus,
|
||||
@ -178,7 +179,8 @@ def create_logprobs(token_ids: List[int],
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def create_chat_completion(raw_request: Request):
|
||||
async def create_chat_completion(request: ChatCompletionRequest,
|
||||
raw_request: Request):
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/chat/create
|
||||
@ -188,9 +190,11 @@ async def create_chat_completion(raw_request: Request):
|
||||
- function_call (Users should implement this by themselves)
|
||||
- logit_bias (to be supported by vLLM engine)
|
||||
"""
|
||||
request = ChatCompletionRequest(**await raw_request.json())
|
||||
logger.info(f"Received chat completion request: {request}")
|
||||
|
||||
if not engine.is_running:
|
||||
engine.start_background_loop()
|
||||
|
||||
error_check_ret = await check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
@ -348,7 +352,7 @@ async def create_chat_completion(raw_request: Request):
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def create_completion(raw_request: Request):
|
||||
async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/completions/create
|
||||
@ -361,9 +365,11 @@ async def create_completion(raw_request: Request):
|
||||
suffix)
|
||||
- logit_bias (to be supported by vLLM engine)
|
||||
"""
|
||||
request = CompletionRequest(**await raw_request.json())
|
||||
logger.info(f"Received completion request: {request}")
|
||||
|
||||
if not engine.is_running:
|
||||
engine.start_background_loop()
|
||||
|
||||
error_check_ret = await check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
@ -621,7 +627,8 @@ if __name__ == "__main__":
|
||||
served_model = args.model
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args,
|
||||
start_engine_loop=False)
|
||||
engine_model_config = asyncio.run(engine.get_model_config())
|
||||
max_model_len = engine_model_config.get_max_model_len()
|
||||
|
||||
|
@ -61,7 +61,6 @@ class PagedAttention(nn.Module):
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.attn_op = xops.fmha.cutlass.FwOp()
|
||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
@ -115,7 +114,6 @@ class PagedAttention(nn.Module):
|
||||
attn_bias=input_metadata.attn_bias[0],
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op=self.attn_op,
|
||||
)
|
||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||
output.copy_(out.squeeze(0))
|
||||
@ -244,7 +242,7 @@ class PagedAttention(nn.Module):
|
||||
|
||||
|
||||
class PagedAttentionWithRoPE(PagedAttention):
|
||||
"""PagedAttention with GPT-NeoX style rotary embedding."""
|
||||
"""PagedAttention with rotary embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -255,8 +253,10 @@ class PagedAttentionWithRoPE(PagedAttention):
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
is_neox_style: bool = True,
|
||||
) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads)
|
||||
self.is_neox_style = is_neox_style
|
||||
|
||||
# Create the cos and sin cache.
|
||||
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
|
||||
@ -305,12 +305,13 @@ class PagedAttentionWithRoPE(PagedAttention):
|
||||
|
||||
# Apply rotary embedding to the query and key before passing them
|
||||
# to the attention op.
|
||||
pos_encoding_ops.rotary_embedding_neox(
|
||||
pos_encoding_ops.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
)
|
||||
return super().forward(
|
||||
query,
|
||||
@ -404,7 +405,6 @@ class PagedAttentionWithALiBi(PagedAttention):
|
||||
attn_bias=input_metadata.attn_bias[i],
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op=self.attn_op,
|
||||
)
|
||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||
output[start:end].copy_(out.squeeze(0))
|
||||
|
@ -9,7 +9,7 @@ from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
gather_from_tensor_model_parallel_region)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput, SequenceOutputs
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
@ -39,7 +39,7 @@ class Sampler(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
embedding_bias: Optional[torch.Tensor] = None,
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
) -> SamplerOutput:
|
||||
# Get the hidden states that we use for sampling.
|
||||
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
|
||||
|
||||
@ -100,7 +100,8 @@ def _prune_hidden_states(
|
||||
start_idx += prompt_len
|
||||
last_token_indicies.extend(
|
||||
range(start_idx, start_idx + input_metadata.num_generation_tokens))
|
||||
return hidden_states[last_token_indicies]
|
||||
return hidden_states.index_select(
|
||||
0, torch.tensor(last_token_indicies, device=hidden_states.device))
|
||||
|
||||
|
||||
def _get_penalties(
|
||||
@ -157,7 +158,7 @@ def _apply_penalties(
|
||||
continue
|
||||
p = presence_penalties[i]
|
||||
f = frequency_penalties[i]
|
||||
if p < _SAMPLING_EPS and f < _SAMPLING_EPS:
|
||||
if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
|
||||
continue
|
||||
indices.append(i)
|
||||
|
||||
@ -291,7 +292,13 @@ def _sample_from_prompt(
|
||||
if sampling_params.use_beam_search:
|
||||
# Beam search.
|
||||
beam_width = sampling_params.best_of
|
||||
_, next_token_ids = torch.topk(prob, beam_width)
|
||||
# Sample 2 * beam_width candidates to make sure that with high
|
||||
# probability we can get `beam_width` candidates in addition to
|
||||
# the finished sequences for the next iteration. See
|
||||
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
|
||||
# for details. See also HF reference:
|
||||
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
|
||||
_, next_token_ids = torch.topk(prob, 2 * beam_width)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
elif sampling_params.temperature < _SAMPLING_EPS:
|
||||
# Greedy sampling.
|
||||
@ -329,29 +336,11 @@ def _sample_from_generation_tokens(
|
||||
|
||||
vocab_size = logprobs.size(-1)
|
||||
beam_width = len(seq_ids)
|
||||
_, topk_ids = torch.topk(logprobs.flatten(), beam_width)
|
||||
_, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
|
||||
topk_ids = topk_ids.tolist()
|
||||
seq_idx = [i // vocab_size for i in topk_ids]
|
||||
beam_seq_ids = [seq_ids[i] for i in seq_idx]
|
||||
token_ids = [i % vocab_size for i in topk_ids]
|
||||
|
||||
beam_outputs: Dict[int, Tuple[int, int]] = {}
|
||||
outstanding_beams: List[Tuple[int, int]] = []
|
||||
# If a beam survives, continue with it.
|
||||
for seq_id, token_id in zip(beam_seq_ids, token_ids):
|
||||
if seq_id not in beam_outputs:
|
||||
beam_outputs[seq_id] = (seq_id, token_id)
|
||||
else:
|
||||
outstanding_beams.append((seq_id, token_id))
|
||||
|
||||
# If a beam is discarded, fork another beam.
|
||||
for seq_id in seq_ids:
|
||||
if seq_id not in beam_outputs:
|
||||
beam_outputs[seq_id] = outstanding_beams.pop()
|
||||
assert not outstanding_beams
|
||||
|
||||
parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
|
||||
next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
|
||||
parent_seq_ids = [seq_ids[i] for i in seq_idx]
|
||||
next_token_ids = [i % vocab_size for i in topk_ids]
|
||||
elif sampling_params.temperature < _SAMPLING_EPS:
|
||||
# Greedy sampling.
|
||||
assert len(seq_ids) == 1
|
||||
@ -373,16 +362,18 @@ def _sample(
|
||||
probs: torch.Tensor,
|
||||
logprobs: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
seq_outputs: Dict[int, SequenceOutputs] = {}
|
||||
) -> SamplerOutput:
|
||||
seq_outputs: SamplerOutput = []
|
||||
|
||||
# TODO(woosuk): Optimize.
|
||||
idx = 0
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_group_outputs: List[SequenceOutputs] = []
|
||||
seq_ids, sampling_params = seq_group
|
||||
if i < input_metadata.num_prompts:
|
||||
# Generate the next tokens for a prompt input.
|
||||
assert len(seq_ids) == sampling_params.best_of
|
||||
assert len(seq_ids) == 1, "Prompt input should have only one seq."
|
||||
parent_seq_id = seq_ids[0]
|
||||
prob = probs[idx]
|
||||
logprob = logprobs[idx]
|
||||
idx += 1
|
||||
@ -394,17 +385,18 @@ def _sample(
|
||||
sampling_params.logprobs)
|
||||
|
||||
# Build the output.
|
||||
for seq_id, next_token_id in zip(seq_ids, next_token_ids):
|
||||
for next_token_id in next_token_ids:
|
||||
output_logprobs = next_logprobs.copy()
|
||||
output_logprobs[next_token_id] = logprob[next_token_id].item()
|
||||
seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
|
||||
next_token_id,
|
||||
output_logprobs)
|
||||
seq_group_outputs.append(
|
||||
SequenceOutputs(parent_seq_id, next_token_id,
|
||||
output_logprobs))
|
||||
else:
|
||||
# Generate the next tokens for generation tokens.
|
||||
prob = probs[idx:idx + len(seq_ids)]
|
||||
logprob = logprobs[idx:idx + len(seq_ids)]
|
||||
idx += len(seq_ids)
|
||||
num_parent_seqs = len(seq_ids)
|
||||
prob = probs[idx:idx + num_parent_seqs]
|
||||
logprob = logprobs[idx:idx + num_parent_seqs]
|
||||
idx += num_parent_seqs
|
||||
|
||||
# Sample the next tokens.
|
||||
seq_logprobs = [
|
||||
@ -421,17 +413,15 @@ def _sample(
|
||||
logprob[j], sampling_params.logprobs)
|
||||
|
||||
# Build the output.
|
||||
for seq_id, parent_seq_id, next_token_id in zip(
|
||||
seq_ids, parent_seq_ids, next_token_ids):
|
||||
for parent_seq_id, next_token_id in zip(parent_seq_ids,
|
||||
next_token_ids):
|
||||
j = seq_ids.index(parent_seq_id)
|
||||
output_logprobs = next_logprobs[parent_seq_id].copy()
|
||||
output_logprobs[next_token_id] = logprob[j,
|
||||
next_token_id].item()
|
||||
seq_outputs[seq_id] = SequenceOutputs(
|
||||
seq_id,
|
||||
parent_seq_id,
|
||||
next_token_id,
|
||||
output_logprobs,
|
||||
)
|
||||
seq_group_outputs.append(
|
||||
SequenceOutputs(parent_seq_id, next_token_id,
|
||||
output_logprobs))
|
||||
seq_outputs.append(seq_group_outputs)
|
||||
|
||||
return seq_outputs
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Utilities for selecting and loading models."""
|
||||
import contextlib
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
@ -30,6 +31,15 @@ _MODEL_REGISTRY = {
|
||||
}
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _set_default_torch_dtype(dtype: torch.dtype):
|
||||
"""Sets the default torch dtype to the given dtype."""
|
||||
old_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(dtype)
|
||||
yield
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
|
||||
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
@ -42,19 +52,18 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
|
||||
def get_model(model_config: ModelConfig) -> nn.Module:
|
||||
model_class = _get_model_architecture(model_config.hf_config)
|
||||
torch.set_default_dtype(model_config.dtype)
|
||||
|
||||
# Create a model instance.
|
||||
# The weights will be initialized as empty tensors.
|
||||
model = model_class(model_config.hf_config)
|
||||
if model_config.use_dummy_weights:
|
||||
model = model.cuda()
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
# random values to the weights.
|
||||
initialize_dummy_weights(model)
|
||||
else:
|
||||
# Load the weights from the cached or downloaded files.
|
||||
model.load_weights(model_config.model, model_config.download_dir,
|
||||
model_config.use_np_weights)
|
||||
model = model.cuda()
|
||||
with _set_default_torch_dtype(model_config.dtype):
|
||||
# Create a model instance.
|
||||
# The weights will be initialized as empty tensors.
|
||||
model = model_class(model_config.hf_config)
|
||||
if model_config.load_format == "dummy":
|
||||
model = model.cuda()
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
# random values to the weights.
|
||||
initialize_dummy_weights(model)
|
||||
else:
|
||||
# Load the weights from the cached or downloaded files.
|
||||
model.load_weights(model_config.model, model_config.download_dir,
|
||||
model_config.load_format)
|
||||
model = model.cuda()
|
||||
return model.eval()
|
||||
|
@ -25,7 +25,7 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -34,13 +34,14 @@ from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.weight_utils import (
|
||||
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.aquila import AquilaConfig
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -272,7 +273,7 @@ class AquilaForCausalLM(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||
@ -280,15 +281,14 @@ class AquilaForCausalLM(nn.Module):
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
|
||||
"gate_proj.weight", "up_proj.weight"
|
||||
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
|
||||
]
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
load_format: str = "auto"):
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||
@ -305,20 +305,10 @@ class AquilaForCausalLM(nn.Module):
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
param = state_dict[name]
|
||||
# Consider padding in the vocab size.
|
||||
padded_vocab_size = (param.shape[0] * tp_size)
|
||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
||||
extra_rows = torch.empty(num_extra_rows,
|
||||
loaded_weight.shape[1])
|
||||
extra_rows = extra_rows.to(loaded_weight)
|
||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
||||
|
||||
is_attention_weight = False
|
||||
for weight_name, shard_size, offset in attention_weight_specs:
|
||||
if weight_name not in name:
|
||||
@ -356,6 +346,11 @@ class AquilaForCausalLM(nn.Module):
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tensor_model_parallel_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
|
@ -23,23 +23,25 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
import math
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
|
||||
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
|
||||
PagedAttentionWithALiBi)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -288,41 +290,30 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"embed_tokens.weight",
|
||||
"lm_head.weight",
|
||||
]
|
||||
_column_parallel_weights = []
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
load_format: str = "auto"):
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
# Consider padding in the vocab size.
|
||||
param = state_dict[name]
|
||||
padded_vocab_size = param.shape[0] * tp_world_size
|
||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
||||
extra_rows = torch.empty(num_extra_rows,
|
||||
loaded_weight.shape[1])
|
||||
extra_rows = extra_rows.to(loaded_weight)
|
||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
|
||||
if "W_pack" in name:
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
@ -355,6 +346,12 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tp_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(
|
||||
param,
|
||||
loaded_weight,
|
||||
|
@ -21,7 +21,7 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
import math
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
@ -264,7 +264,7 @@ class BloomForCausalLM(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||
@ -279,11 +279,11 @@ class BloomForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
load_format: str = "auto"):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if name == "lm_head.weight":
|
||||
# Since hidden_states are parallelized, we need to
|
||||
# load lm_head.weight in parallel.
|
||||
|
@ -19,7 +19,7 @@
|
||||
"""PyTorch Falcon model."""
|
||||
|
||||
import math
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -31,14 +31,15 @@ from vllm.model_executor.layers.attention import (PagedAttention,
|
||||
PagedAttentionWithALiBi,
|
||||
PagedAttentionWithRoPE)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
|
||||
hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
|
||||
reduce_from_tensor_model_parallel_region)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs import RWConfig
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -397,7 +398,7 @@ class FalconForCausalLM(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.transformer(
|
||||
input_ids,
|
||||
positions,
|
||||
@ -419,7 +420,7 @@ class FalconForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
load_format: str = "auto"):
|
||||
tp_size = (get_tensor_model_parallel_world_size())
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
@ -451,8 +452,9 @@ class FalconForCausalLM(nn.Module):
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "query_key_value" in name:
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
loaded_weight_size = loaded_weight.size()
|
||||
loaded_weight = loaded_weight.view(
|
||||
total_num_kv_heads, num_query_heads_per_kv_head + 2,
|
||||
|
@ -21,7 +21,7 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -31,13 +31,14 @@ from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import PagedAttention
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
@ -217,27 +218,27 @@ class GPT2LMHeadModel(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
|
||||
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
|
||||
_row_parallel_weights = ["c_proj.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
load_format: str = "auto"):
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "lm_head.weight" in name:
|
||||
# GPT-2 ties the weights of the embedding layer and the final
|
||||
# linear layer.
|
||||
@ -250,6 +251,8 @@ class GPT2LMHeadModel(nn.Module):
|
||||
if not name.startswith("transformer."):
|
||||
name = "transformer." + name
|
||||
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
|
||||
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
|
||||
# Because of this, we need to transpose the weights.
|
||||
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
|
||||
@ -261,14 +264,9 @@ class GPT2LMHeadModel(nn.Module):
|
||||
param = state_dict[name]
|
||||
|
||||
if name == "transformer.wte.weight":
|
||||
# Consider padding in the vocab size.
|
||||
padded_vocab_size = (param.shape[0] *
|
||||
tensor_model_parallel_world_size)
|
||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
||||
extra_rows = torch.empty(num_extra_rows,
|
||||
loaded_weight.shape[1])
|
||||
extra_rows = extra_rows.to(loaded_weight)
|
||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tensor_model_parallel_rank)
|
||||
continue
|
||||
|
||||
# For the fused QKV linear layer, manually shard the weights.
|
||||
if "c_attn" in name:
|
||||
|
@ -22,7 +22,7 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -32,13 +32,14 @@ from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import PagedAttention
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
@ -245,27 +246,27 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["wte.weight", "c_fc.weight", "c_fc.bias"]
|
||||
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
|
||||
_row_parallel_weights = ["c_proj.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
load_format: str = "auto"):
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "lm_head.weight" in name:
|
||||
# GPT-2 ties the weights of the embedding layer and the final
|
||||
# linear layer.
|
||||
@ -294,6 +295,7 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
head_start = tensor_model_parallel_rank * num_heads
|
||||
head_end = (tensor_model_parallel_rank + 1) * num_heads
|
||||
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
wq, wk, wv = torch.split(
|
||||
loaded_weight, [hidden_size, total_kv_size, total_kv_size],
|
||||
dim=0)
|
||||
@ -328,14 +330,9 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
param = state_dict[name]
|
||||
|
||||
if name == "transformer.wte.weight":
|
||||
# Consider padding in the vocab size.
|
||||
padded_vocab_size = param.shape[
|
||||
0] * tensor_model_parallel_world_size
|
||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
||||
extra_rows = torch.empty(num_extra_rows,
|
||||
loaded_weight.shape[1])
|
||||
extra_rows = extra_rows.to(loaded_weight)
|
||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tensor_model_parallel_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
|
@ -20,7 +20,7 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -36,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
@ -67,8 +67,11 @@ class GPTJAttention(nn.Module):
|
||||
scaling = self.head_size**-0.5
|
||||
assert getattr(config, "rotary", True)
|
||||
assert config.rotary_dim % 2 == 0
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
|
||||
scaling, config.rotary_dim)
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_size,
|
||||
scaling,
|
||||
config.rotary_dim,
|
||||
is_neox_style=False)
|
||||
self.warmup = False
|
||||
|
||||
def forward(
|
||||
@ -203,7 +206,7 @@ class GPTJForCausalLM(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||
@ -219,11 +222,11 @@ class GPTJForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
load_format: str = "auto"):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "attn.bias" in name or "attn.masked_bias" in name:
|
||||
continue
|
||||
|
||||
|
@ -20,7 +20,7 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -36,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
@ -215,7 +215,7 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.embed_out.weight, hidden_states,
|
||||
@ -231,11 +231,11 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
load_format: str = "auto"):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if ("attention.bias" in name or "attention.masked_bias" in name
|
||||
or "rotary_emb.inv_freq" in name):
|
||||
continue
|
||||
|
@ -1,5 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -14,9 +14,10 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.model_executor.weight_utils import (
|
||||
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
@ -217,7 +218,7 @@ class InternLMForCausalLM(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||
@ -225,35 +226,27 @@ class InternLMForCausalLM(nn.Module):
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
|
||||
"gate_proj.weight", "up_proj.weight"
|
||||
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
|
||||
]
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
load_format: str = "auto"):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
param = state_dict[name]
|
||||
# Consider padding in the vocab size.
|
||||
padded_vocab_size = (param.shape[0] *
|
||||
tensor_model_parallel_world_size)
|
||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
||||
extra_rows = torch.empty(num_extra_rows,
|
||||
loaded_weight.shape[1])
|
||||
extra_rows = extra_rows.to(loaded_weight)
|
||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tensor_model_parallel_rank)
|
||||
continue
|
||||
|
||||
is_attention_weight = False
|
||||
for stride_id, att_weight_name in enumerate(
|
||||
|
@ -25,7 +25,7 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -36,13 +36,14 @@ from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.weight_utils import (
|
||||
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
@ -85,6 +86,7 @@ class LlamaAttention(nn.Module):
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -99,6 +101,7 @@ class LlamaAttention(nn.Module):
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
@ -118,6 +121,7 @@ class LlamaAttention(nn.Module):
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
base=self.rope_theta,
|
||||
rotary_dim=self.head_dim,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
|
||||
@ -143,10 +147,13 @@ class LlamaDecoderLayer(nn.Module):
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
self.self_attn = LlamaAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
)
|
||||
self.mlp = LlamaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
@ -249,7 +256,7 @@ class LlamaForCausalLM(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||
@ -257,15 +264,14 @@ class LlamaForCausalLM(nn.Module):
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
|
||||
"gate_proj.weight", "up_proj.weight"
|
||||
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
|
||||
]
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
load_format: str = "auto"):
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||
@ -282,20 +288,10 @@ class LlamaForCausalLM(nn.Module):
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
param = state_dict[name]
|
||||
# Consider padding in the vocab size.
|
||||
padded_vocab_size = (param.shape[0] * tp_size)
|
||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
||||
extra_rows = torch.empty(num_extra_rows,
|
||||
loaded_weight.shape[1])
|
||||
extra_rows = extra_rows.to(loaded_weight)
|
||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
||||
|
||||
is_attention_weight = False
|
||||
for weight_name, shard_size, offset in attention_weight_specs:
|
||||
if weight_name not in name:
|
||||
@ -333,6 +329,12 @@ class LlamaForCausalLM(nn.Module):
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tensor_model_parallel_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
|
@ -1,7 +1,7 @@
|
||||
# coding=utf-8
|
||||
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
|
||||
import math
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -10,13 +10,14 @@ from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
|
||||
hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.mpt import MPTConfig
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -230,7 +231,7 @@ class MPTForCausalLM(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||
@ -243,12 +244,12 @@ class MPTForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
load_format: str = "auto"):
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "Wqkv" in name:
|
||||
# NOTE(woosuk): MPT's fused QKV has the shape of
|
||||
# [3 * num_heads * head_size, hidden_size].
|
||||
@ -260,7 +261,7 @@ class MPTForCausalLM(nn.Module):
|
||||
num_heads = total_num_heads // tp_world_size
|
||||
head_start = tp_rank * num_heads
|
||||
head_end = (tp_rank + 1) * num_heads
|
||||
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
if name.endswith(".weight"):
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size, hidden_size)
|
||||
|
@ -21,7 +21,7 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
@ -282,7 +282,7 @@ class OPTForCausalLM(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||
@ -297,12 +297,12 @@ class OPTForCausalLM(nn.Module):
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False):
|
||||
load_format: str = "auto"):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "lm_head.weight" in name:
|
||||
continue
|
||||
|
||||
|
@ -8,7 +8,7 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -19,7 +19,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor,
|
||||
hf_model_weights_iterator,
|
||||
load_padded_tensor_parallel_vocab,
|
||||
load_tensor_parallel_weights,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
@ -31,7 +33,7 @@ from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.qwen import QWenConfig
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -234,40 +236,32 @@ class QWenLMHeadModel(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["wte.weight", "lm_head.weight"]
|
||||
_column_parallel_weights = []
|
||||
_row_parallel_weights = ["c_proj.weight"]
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False,
|
||||
load_format: str = "auto",
|
||||
):
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, use_np_cache):
|
||||
model_name_or_path, cache_dir, load_format):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if "wte" in name or "lm_head" in name:
|
||||
# Consider padding in the vocab size.
|
||||
param = state_dict[name]
|
||||
padded_vocab_size = param.shape[0] * tp_world_size
|
||||
num_extra_rows = padded_vocab_size - self.config.vocab_size
|
||||
extra_rows = torch.empty(num_extra_rows,
|
||||
loaded_weight.shape[1])
|
||||
extra_rows = extra_rows.to(loaded_weight)
|
||||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
|
||||
if "c_attn" in name:
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
@ -306,6 +300,12 @@ class QWenLMHeadModel(nn.Module):
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
|
||||
if "wte" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tp_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(
|
||||
param,
|
||||
loaded_weight,
|
||||
|
@ -3,13 +3,19 @@ import filelock
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from typing import Iterator, List, Optional, Tuple
|
||||
from collections import defaultdict
|
||||
from typing import Iterator, List, Optional, Tuple, Any
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Disabledtqdm(tqdm):
|
||||
|
||||
@ -17,43 +23,140 @@ class Disabledtqdm(tqdm):
|
||||
super().__init__(*args, **kwargs, disable=True)
|
||||
|
||||
|
||||
def hf_model_weights_iterator(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_np_cache: bool = False,
|
||||
) -> Iterator[Tuple[str, torch.Tensor]]:
|
||||
# Prepare file lock directory to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
|
||||
lock_dir = cache_dir if cache_dir is not None else "/tmp"
|
||||
lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
|
||||
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
|
||||
return lock
|
||||
|
||||
|
||||
def _shared_pointers(tensors):
|
||||
ptrs = defaultdict(list)
|
||||
for k, v in tensors.items():
|
||||
ptrs[v.data_ptr()].append(k)
|
||||
failing = []
|
||||
for _, names in ptrs.items():
|
||||
if len(names) > 1:
|
||||
failing.append(names)
|
||||
return failing
|
||||
|
||||
|
||||
def convert_bin_to_safetensor_file(
|
||||
pt_filename: str,
|
||||
sf_filename: str,
|
||||
):
|
||||
loaded = torch.load(pt_filename, map_location="cpu")
|
||||
if "state_dict" in loaded:
|
||||
loaded = loaded["state_dict"]
|
||||
shared = _shared_pointers(loaded)
|
||||
for shared_weights in shared:
|
||||
for name in shared_weights[1:]:
|
||||
loaded.pop(name)
|
||||
|
||||
# For tensors to be contiguous
|
||||
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
||||
|
||||
dirname = os.path.dirname(sf_filename)
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
save_file(loaded, sf_filename, metadata={"format": "pt"})
|
||||
|
||||
# check file size
|
||||
sf_size = os.stat(sf_filename).st_size
|
||||
pt_size = os.stat(pt_filename).st_size
|
||||
if (sf_size - pt_size) / pt_size > 0.01:
|
||||
raise RuntimeError(f"""The file size different is more than 1%:
|
||||
- {sf_filename}: {sf_size}
|
||||
- {pt_filename}: {pt_size}
|
||||
""")
|
||||
|
||||
# check if the tensors are the same
|
||||
reloaded = load_file(sf_filename)
|
||||
for k in loaded:
|
||||
pt_tensor = loaded[k]
|
||||
sf_tensor = reloaded[k]
|
||||
if not torch.equal(pt_tensor, sf_tensor):
|
||||
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||
|
||||
|
||||
def prepare_hf_model_weights(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_safetensors: bool = False,
|
||||
fall_back_to_pt: bool = True,
|
||||
):
|
||||
# Download model weights from huggingface.
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
|
||||
if not is_local:
|
||||
with lock:
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
hf_folder = snapshot_download(model_name_or_path,
|
||||
allow_patterns="*.bin",
|
||||
allow_patterns=allow_patterns,
|
||||
cache_dir=cache_dir,
|
||||
tqdm_class=Disabledtqdm)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
|
||||
if not use_safetensors:
|
||||
hf_weights_files = [
|
||||
x for x in hf_weights_files if not x.endswith("training_args.bin")
|
||||
]
|
||||
|
||||
hf_bin_files = [
|
||||
x for x in glob.glob(os.path.join(hf_folder, "*.bin"))
|
||||
if not x.endswith("training_args.bin")
|
||||
]
|
||||
if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
|
||||
return prepare_hf_model_weights(model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
use_safetensors=False,
|
||||
fall_back_to_pt=False)
|
||||
|
||||
if len(hf_weights_files) == 0:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any model weights with `{model_name_or_path}`")
|
||||
|
||||
return hf_folder, hf_weights_files, use_safetensors
|
||||
|
||||
|
||||
def hf_model_weights_iterator(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
) -> Iterator[Tuple[str, torch.Tensor]]:
|
||||
use_safetensors = False
|
||||
use_np_cache = False
|
||||
fall_back_to_pt = False
|
||||
if load_format == "auto":
|
||||
use_safetensors = True
|
||||
fall_back_to_pt = True
|
||||
elif load_format == "safetensors":
|
||||
use_safetensors = True
|
||||
elif load_format == "pt":
|
||||
pass
|
||||
elif load_format == "npcache":
|
||||
use_np_cache = True
|
||||
else:
|
||||
raise ValueError(f"Unknown load_format: {load_format}")
|
||||
|
||||
hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
|
||||
model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
use_safetensors=use_safetensors,
|
||||
fall_back_to_pt=fall_back_to_pt)
|
||||
|
||||
if use_np_cache:
|
||||
# Currently np_cache only support *.bin checkpoints
|
||||
assert use_safetensors is False
|
||||
|
||||
# Convert the model weights from torch tensors to numpy arrays for
|
||||
# faster loading.
|
||||
np_folder = os.path.join(hf_folder, "np")
|
||||
os.makedirs(np_folder, exist_ok=True)
|
||||
weight_names_file = os.path.join(np_folder, "weight_names.json")
|
||||
with lock:
|
||||
# Use file lock to prevent multiple processes from
|
||||
# dumping the same model weights to numpy at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
if not os.path.exists(weight_names_file):
|
||||
weight_names = []
|
||||
for bin_file in hf_bin_files:
|
||||
for bin_file in hf_weights_files:
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
for name, param in state.items():
|
||||
param_path = os.path.join(np_folder, name)
|
||||
@ -71,8 +174,14 @@ def hf_model_weights_iterator(
|
||||
with open(param_path, "rb") as f:
|
||||
param = np.load(f)
|
||||
yield name, torch.from_numpy(param)
|
||||
elif use_safetensors:
|
||||
for st_file in hf_weights_files:
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
for name in f.keys():
|
||||
param = f.get_slice(name)
|
||||
yield name, param
|
||||
else:
|
||||
for bin_file in hf_bin_files:
|
||||
for bin_file in hf_weights_files:
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
for name, param in state.items():
|
||||
yield name, param
|
||||
@ -80,9 +189,37 @@ def hf_model_weights_iterator(
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
|
||||
"""convert PySafeSlice object from safetensors to torch.Tensor
|
||||
|
||||
PySafeSlice object supports indexing, which is done before loading the
|
||||
actual tensor and can reduce the amount of memory being read into the
|
||||
memory. However, it does not support more advanced functionalities
|
||||
like `.view()` or `.t()`. Therefore, if we need to modify the loaded
|
||||
tensor with these more complicated operators, we need to convert to
|
||||
tensor first.
|
||||
"""
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = x[:]
|
||||
return x
|
||||
|
||||
|
||||
def load_padded_tensor_parallel_vocab(
|
||||
param: torch.Tensor,
|
||||
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
|
||||
tensor_model_parallel_rank: int,
|
||||
) -> None:
|
||||
shard_size = param.shape[0]
|
||||
start_idx = tensor_model_parallel_rank * shard_size
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
loaded_weight = loaded_weight[start_idx:end_idx]
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
param[:loaded_weight.shape[0]].copy_(loaded_weight)
|
||||
|
||||
|
||||
def load_tensor_parallel_weights(
|
||||
param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
|
||||
param_name: str,
|
||||
column_parallel_weight_names: List[str],
|
||||
row_parallel_weight_names: List[str],
|
||||
@ -102,6 +239,8 @@ def load_tensor_parallel_weights(
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
loaded_weight = loaded_weight[:, start_idx:end_idx]
|
||||
break
|
||||
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
assert param.shape == loaded_weight.shape, (
|
||||
f"{param_name} shape mismatch between model and checkpoint: "
|
||||
f"{param.shape} != {loaded_weight.shape}")
|
||||
|
@ -75,10 +75,12 @@ class RequestOutput:
|
||||
# Get the top-n sequences.
|
||||
n = seq_group.sampling_params.n
|
||||
seqs = seq_group.get_seqs()
|
||||
assert n <= len(seqs)
|
||||
sorted_seqs = sorted(seqs,
|
||||
key=lambda seq: seq.get_cumulative_logprob(),
|
||||
reverse=True)
|
||||
if seq_group.sampling_params.use_beam_search:
|
||||
sorting_key = lambda seq: seq.get_beam_search_score(
|
||||
seq_group.sampling_params.length_penalty)
|
||||
else:
|
||||
sorting_key = lambda seq: seq.get_cumulative_logprob()
|
||||
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
|
||||
top_n_seqs = sorted_seqs[:n]
|
||||
|
||||
# Create the outputs.
|
||||
|
@ -34,6 +34,15 @@ class SamplingParams:
|
||||
top_k: Integer that controls the number of top tokens to consider. Set
|
||||
to -1 to consider all tokens.
|
||||
use_beam_search: Whether to use beam search instead of sampling.
|
||||
length_penalty: Float that penalizes sequences based on their length.
|
||||
Used in beam search.
|
||||
early_stopping: Controls the stopping condition for beam search. It
|
||||
accepts the following values: `True`, where the generation stops as
|
||||
soon as there are `best_of` complete candidates; `False`, where an
|
||||
heuristic is applied and the generation stops when is it very
|
||||
unlikely to find better candidates; `"never"`, where the beam search
|
||||
procedure only stops when there cannot be better candidates
|
||||
(canonical beam search algorithm).
|
||||
stop: List of strings that stop the generation when they are generated.
|
||||
The returned output will not contain the stop strings.
|
||||
ignore_eos: Whether to ignore the EOS token and continue generating
|
||||
@ -52,6 +61,8 @@ class SamplingParams:
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
use_beam_search: bool = False,
|
||||
length_penalty: float = 1.0,
|
||||
early_stopping: Union[bool, str] = False,
|
||||
stop: Union[None, str, List[str]] = None,
|
||||
ignore_eos: bool = False,
|
||||
max_tokens: int = 16,
|
||||
@ -65,6 +76,8 @@ class SamplingParams:
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.use_beam_search = use_beam_search
|
||||
self.length_penalty = length_penalty
|
||||
self.early_stopping = early_stopping
|
||||
if stop is None:
|
||||
self.stop = []
|
||||
elif isinstance(stop, str):
|
||||
@ -78,9 +91,11 @@ class SamplingParams:
|
||||
self._verify_args()
|
||||
if self.use_beam_search:
|
||||
self._verify_beam_search()
|
||||
elif self.temperature < _SAMPLING_EPS:
|
||||
# Zero temperature means greedy sampling.
|
||||
self._verify_greedy_sampling()
|
||||
else:
|
||||
self._verify_non_beam_search()
|
||||
if self.temperature < _SAMPLING_EPS:
|
||||
# Zero temperature means greedy sampling.
|
||||
self._verify_greedy_sampling()
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
if self.n < 1:
|
||||
@ -119,6 +134,20 @@ class SamplingParams:
|
||||
raise ValueError("top_p must be 1 when using beam search.")
|
||||
if self.top_k != -1:
|
||||
raise ValueError("top_k must be -1 when using beam search.")
|
||||
if self.early_stopping not in [True, False, "never"]:
|
||||
raise ValueError(
|
||||
f"early_stopping must be True, False, or 'never', "
|
||||
f"got {self.early_stopping}.")
|
||||
|
||||
def _verify_non_beam_search(self) -> None:
|
||||
if self.early_stopping is not False:
|
||||
raise ValueError("early_stopping is not effective and must be "
|
||||
"False when not using beam search.")
|
||||
if (self.length_penalty < 1.0 - _SAMPLING_EPS
|
||||
or self.length_penalty > 1.0 + _SAMPLING_EPS):
|
||||
raise ValueError(
|
||||
"length_penalty is not effective and must be the "
|
||||
"default value of 1.0 when not using beam search.")
|
||||
|
||||
def _verify_greedy_sampling(self) -> None:
|
||||
if self.best_of > 1:
|
||||
@ -138,6 +167,8 @@ class SamplingParams:
|
||||
f"top_p={self.top_p}, "
|
||||
f"top_k={self.top_k}, "
|
||||
f"use_beam_search={self.use_beam_search}, "
|
||||
f"length_penalty={self.length_penalty}, "
|
||||
f"early_stopping={self.early_stopping}, "
|
||||
f"stop={self.stop}, "
|
||||
f"ignore_eos={self.ignore_eos}, "
|
||||
f"max_tokens={self.max_tokens}, "
|
||||
|
@ -69,6 +69,9 @@ class SequenceData:
|
||||
def get_len(self) -> int:
|
||||
return len(self.output_token_ids) + len(self.prompt_token_ids)
|
||||
|
||||
def get_prompt_len(self) -> int:
|
||||
return len(self.prompt_token_ids)
|
||||
|
||||
def get_output_len(self) -> int:
|
||||
return len(self.output_token_ids)
|
||||
|
||||
@ -155,6 +158,9 @@ class Sequence:
|
||||
def get_len(self) -> int:
|
||||
return self.data.get_len()
|
||||
|
||||
def get_prompt_len(self) -> int:
|
||||
return self.data.get_prompt_len()
|
||||
|
||||
def get_output_len(self) -> int:
|
||||
return self.data.get_output_len()
|
||||
|
||||
@ -170,14 +176,32 @@ class Sequence:
|
||||
def get_cumulative_logprob(self) -> float:
|
||||
return self.data.cumulative_logprob
|
||||
|
||||
def get_beam_search_score(self,
|
||||
length_penalty: float = 0.0,
|
||||
seq_len: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None) -> float:
|
||||
"""Calculate the beam search score with length penalty.
|
||||
|
||||
Adapted from
|
||||
|
||||
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
|
||||
"""
|
||||
if seq_len is None:
|
||||
seq_len = self.get_len()
|
||||
# Note: HF implementation does not count the EOS token
|
||||
# towards the length, we align with that here for testing.
|
||||
if (eos_token_id is not None
|
||||
and self.get_last_token_id() == eos_token_id):
|
||||
seq_len -= 1
|
||||
return self.get_cumulative_logprob() / (seq_len**length_penalty)
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return SequenceStatus.is_finished(self.status)
|
||||
|
||||
def fork(self, child_seq: "Sequence") -> None:
|
||||
child_seq.logical_token_blocks = copy.deepcopy(
|
||||
self.logical_token_blocks)
|
||||
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
|
||||
child_seq.data = copy.deepcopy(self.data)
|
||||
def fork(self, new_seq_id: int) -> "Sequence":
|
||||
new_seq = copy.deepcopy(self)
|
||||
new_seq.seq_id = new_seq_id
|
||||
return new_seq
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"Sequence(seq_id={self.seq_id}, "
|
||||
@ -203,35 +227,66 @@ class SequenceGroup:
|
||||
arrival_time: float,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.seqs = seqs
|
||||
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
|
||||
self.sampling_params = sampling_params
|
||||
self.arrival_time = arrival_time
|
||||
|
||||
def get_max_num_running_seqs(self) -> int:
|
||||
"""The maximum number of sequences running in parallel in the remaining
|
||||
lifetime of the request."""
|
||||
if self.sampling_params.use_beam_search:
|
||||
# For beam search, maximally there will always be `best_of` beam
|
||||
# candidates running in the future.
|
||||
return self.sampling_params.best_of
|
||||
else:
|
||||
if self.sampling_params.best_of > self.num_seqs():
|
||||
# At prompt stage, the sequence group is not yet filled up
|
||||
# and only have one sequence running. However, in the
|
||||
# generation stage, we will have `best_of` sequences running.
|
||||
return self.sampling_params.best_of
|
||||
# At sampling stages, return the number of actual sequences
|
||||
# running.
|
||||
return self.num_seqs(status=SequenceStatus.RUNNING)
|
||||
|
||||
def get_seqs(
|
||||
self,
|
||||
status: Optional[SequenceStatus] = None,
|
||||
) -> List[Sequence]:
|
||||
if status is None:
|
||||
return self.seqs
|
||||
return list(self.seqs_dict.values())
|
||||
else:
|
||||
return [seq for seq in self.seqs if seq.status == status]
|
||||
return [
|
||||
seq for seq in self.seqs_dict.values() if seq.status == status
|
||||
]
|
||||
|
||||
def get_finished_seqs(self) -> List[Sequence]:
|
||||
return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
|
||||
|
||||
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
|
||||
return len(self.get_seqs(status))
|
||||
|
||||
def find(self, seq_id: int) -> Sequence:
|
||||
for seq in self.seqs:
|
||||
if seq.seq_id == seq_id:
|
||||
return seq
|
||||
raise ValueError(f"Sequence {seq_id} not found.")
|
||||
if seq_id not in self.seqs_dict:
|
||||
raise ValueError(f"Sequence {seq_id} not found.")
|
||||
return self.seqs_dict[seq_id]
|
||||
|
||||
def add(self, seq: Sequence) -> None:
|
||||
if seq.seq_id in self.seqs_dict:
|
||||
raise ValueError(f"Sequence {seq.seq_id} already exists.")
|
||||
self.seqs_dict[seq.seq_id] = seq
|
||||
|
||||
def remove(self, seq_id: int) -> None:
|
||||
if seq_id not in self.seqs_dict:
|
||||
raise ValueError(f"Sequence {seq_id} not found.")
|
||||
del self.seqs_dict[seq_id]
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return all(seq.is_finished() for seq in self.seqs)
|
||||
return all(seq.is_finished() for seq in self.get_seqs())
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SequenceGroup(request_id={self.request_id}, "
|
||||
f"sampling_params={self.sampling_params}, "
|
||||
f"num_seqs={len(self.seqs)})")
|
||||
f"num_seqs={len(self.seqs_dict)})")
|
||||
|
||||
|
||||
class SequenceGroupMetadata:
|
||||
@ -266,7 +321,6 @@ class SequenceOutputs:
|
||||
"""The model output associated with a sequence.
|
||||
|
||||
Args:
|
||||
seq_id: The ID of the sequence.
|
||||
parent_seq_id: The ID of the parent sequence (for forking in beam
|
||||
search).
|
||||
output_token: The output token ID.
|
||||
@ -276,26 +330,27 @@ class SequenceOutputs:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
seq_id: int,
|
||||
parent_seq_id: int,
|
||||
output_token: int,
|
||||
logprobs: Dict[int, float],
|
||||
) -> None:
|
||||
self.seq_id = seq_id
|
||||
self.parent_seq_id = parent_seq_id
|
||||
self.output_token = output_token
|
||||
self.logprobs = logprobs
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SequenceOutputs(seq_id={self.seq_id}, "
|
||||
f"parent_seq_id={self.parent_seq_id}, "
|
||||
return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
|
||||
f"output_token={self.output_token}), "
|
||||
f"logprobs={self.logprobs}")
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, SequenceOutputs):
|
||||
return NotImplemented
|
||||
return (self.seq_id == other.seq_id
|
||||
and self.parent_seq_id == other.parent_seq_id
|
||||
return NotImplementedError()
|
||||
return (self.parent_seq_id == other.parent_seq_id
|
||||
and self.output_token == other.output_token
|
||||
and self.logprobs == other.logprobs)
|
||||
|
||||
|
||||
# For each sequence group, we generate a list of SequenceOutputs object,
|
||||
# each of which contains one possible candidate for the next token.
|
||||
SamplerOutput = List[List[SequenceOutputs]]
|
||||
|
@ -25,7 +25,8 @@ def get_tokenizer(
|
||||
"Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||
kwargs["use_fast"] = False
|
||||
|
||||
if "llama" in tokenizer_name.lower() and kwargs.get("use_fast", True):
|
||||
if ("llama" in tokenizer_name.lower() and kwargs.get("use_fast", True)
|
||||
and tokenizer_name != _FAST_LLAMA_TOKENIZER):
|
||||
logger.info(
|
||||
"For some LLaMA-based models, initializing the fast tokenizer may "
|
||||
"take a long time. To eliminate the initialization time, consider "
|
||||
|
@ -11,7 +11,7 @@ from vllm.model_executor import get_model, InputMetadata, set_random_seed
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
initialize_model_parallel)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import SequenceData, SequenceGroupMetadata, SequenceOutputs
|
||||
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.utils import get_gpu_memory
|
||||
|
||||
@ -260,7 +260,7 @@ class Worker:
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
) -> SamplerOutput:
|
||||
# Issue cache operations.
|
||||
issued_cache_op = False
|
||||
if blocks_to_swap_in:
|
||||
|
Reference in New Issue
Block a user