mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Compare commits
4 Commits
v0.11.1rc1
...
v1-block-t
Author | SHA1 | Date | |
---|---|---|---|
44d638a896 | |||
caacd1ddfb | |||
e68f63ef83 | |||
223e17424c |
@ -228,6 +228,7 @@ endif()
|
||||
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cache_kernels.cu"
|
||||
"csrc/block_table.cu"
|
||||
"csrc/attention/paged_attention_v1.cu"
|
||||
"csrc/attention/paged_attention_v2.cu"
|
||||
"csrc/pos_encoding_kernels.cu"
|
||||
|
92
csrc/block_table.cu
Normal file
92
csrc/block_table.cu
Normal file
@ -0,0 +1,92 @@
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
namespace vllm {
|
||||
__global__ void append_kernel(const int* __restrict__ row_indices,
|
||||
const int* __restrict__ cu_num_appends,
|
||||
const int* __restrict__ block_ids,
|
||||
int* __restrict__ block_table,
|
||||
int max_num_blocks_per_row) {
|
||||
int bid = blockIdx.x;
|
||||
int tgt_row = row_indices[2 * bid];
|
||||
int tgt_offset = row_indices[2 * bid + 1];
|
||||
|
||||
int start = cu_num_appends[bid];
|
||||
int end = cu_num_appends[bid + 1];
|
||||
int length = end - start;
|
||||
int tid = threadIdx.x;
|
||||
int64_t offset = tgt_row * max_num_blocks_per_row + tgt_offset;
|
||||
for (int i = tid; i < length; i += blockDim.x) {
|
||||
block_table[offset + i] = block_ids[start + i];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void move_kernel(const int* __restrict__ src_dst_n,
|
||||
int* __restrict__ block_table,
|
||||
int max_num_blocks_per_row) {
|
||||
int bid = blockIdx.x;
|
||||
int src_row = src_dst_n[3 * bid];
|
||||
int tgt_row = src_dst_n[3 * bid + 1];
|
||||
int num_blocks = src_dst_n[3 * bid + 2];
|
||||
|
||||
int tid = threadIdx.x;
|
||||
for (int i = tid; i < num_blocks; i += blockDim.x) {
|
||||
block_table[tgt_row * max_num_blocks_per_row + i] =
|
||||
block_table[src_row * max_num_blocks_per_row + i];
|
||||
}
|
||||
}
|
||||
} // namespace vllm
|
||||
|
||||
void block_table_appends(
|
||||
torch::Tensor& append_row_indices,
|
||||
torch::Tensor& append_row_indices_cpu,
|
||||
torch::Tensor& append_cumsums,
|
||||
torch::Tensor& append_cumsums_cpu,
|
||||
torch::Tensor& append_block_ids,
|
||||
torch::Tensor& append_block_ids_cpu,
|
||||
torch::Tensor& block_table,
|
||||
int64_t num_appends,
|
||||
int64_t total_num_append_blocks) {
|
||||
int* append_row_indices_ptr = append_row_indices.data_ptr<int>();
|
||||
const int* append_row_indices_cpu_ptr = append_row_indices_cpu.data_ptr<int>();
|
||||
int* append_cumsums_ptr = append_cumsums.data_ptr<int>();
|
||||
const int* append_cumsums_cpu_ptr = append_cumsums_cpu.data_ptr<int>();
|
||||
int* append_block_ids_ptr = append_block_ids.data_ptr<int>();
|
||||
const int* append_block_ids_cpu_ptr = append_block_ids_cpu.data_ptr<int>();
|
||||
int* block_table_ptr = block_table.data_ptr<int>();
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(block_table));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
cudaMemcpyAsync(append_row_indices_ptr, append_row_indices_cpu_ptr,
|
||||
num_appends * 2 * sizeof(int), cudaMemcpyHostToDevice, stream);
|
||||
cudaMemcpyAsync(append_cumsums_ptr, append_cumsums_cpu_ptr,
|
||||
(num_appends + 1) * sizeof(int), cudaMemcpyHostToDevice, stream);
|
||||
cudaMemcpyAsync(append_block_ids_ptr, append_block_ids_cpu_ptr,
|
||||
total_num_append_blocks * sizeof(int), cudaMemcpyHostToDevice, stream);
|
||||
|
||||
int64_t max_num_blocks_per_row = block_table.size(1);
|
||||
vllm::append_kernel<<<num_appends, 1024, 0, stream>>>(
|
||||
append_row_indices_ptr, append_cumsums_ptr, append_block_ids_ptr,
|
||||
block_table_ptr, max_num_blocks_per_row);
|
||||
}
|
||||
|
||||
void block_table_moves(
|
||||
torch::Tensor& src_dst_n,
|
||||
torch::Tensor& src_dst_n_cpu,
|
||||
torch::Tensor& block_table,
|
||||
int64_t num_moves) {
|
||||
int* src_dst_n_ptr = src_dst_n.data_ptr<int>();
|
||||
const int* src_dst_n_cpu_ptr = src_dst_n_cpu.data_ptr<int>();
|
||||
int* block_table_ptr = block_table.data_ptr<int>();
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(block_table));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
cudaMemcpyAsync(src_dst_n_ptr, src_dst_n_cpu_ptr,
|
||||
num_moves * 3 * sizeof(int), cudaMemcpyHostToDevice, stream);
|
||||
|
||||
int64_t max_num_blocks_per_row = block_table.size(1);
|
||||
vllm::move_kernel<<<num_moves, 1024, 0, stream>>>(
|
||||
src_dst_n_ptr, block_table_ptr, max_num_blocks_per_row);
|
||||
}
|
12
csrc/ops.h
12
csrc/ops.h
@ -119,6 +119,18 @@ void advance_step_flashinfer(
|
||||
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
|
||||
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
|
||||
|
||||
void block_table_appends(torch::Tensor& append_row_indices,
|
||||
torch::Tensor& append_row_indices_cpu,
|
||||
torch::Tensor& append_cumsums,
|
||||
torch::Tensor& append_cumsums_cpu,
|
||||
torch::Tensor& append_block_ids,
|
||||
torch::Tensor& append_block_ids_cpu,
|
||||
torch::Tensor& block_table, int64_t num_appends,
|
||||
int64_t total_num_append_blocks);
|
||||
|
||||
void block_table_moves(torch::Tensor& src_dst_n, torch::Tensor& src_dst_n_cpu,
|
||||
torch::Tensor& block_table, int64_t num_moves);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
|
@ -111,6 +111,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
") -> ()");
|
||||
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
|
||||
|
||||
ops.def(
|
||||
"block_table_appends(Tensor append_row_indices, "
|
||||
"Tensor append_row_indices_cpu, Tensor append_cumsums, "
|
||||
"Tensor append_cumsums_cpu, Tensor append_block_ids, "
|
||||
"Tensor append_block_ids_cpu, Tensor! block_table, int num_appends, "
|
||||
"int total_num_append_blocks) -> ()");
|
||||
ops.impl("block_table_appends", torch::kCUDA, &block_table_appends);
|
||||
|
||||
ops.def(
|
||||
"block_table_moves(Tensor src_dst_n, Tensor src_dst_n_cpu, "
|
||||
"Tensor! block_table, int num_moves) -> ()");
|
||||
ops.impl("block_table_moves", torch::kCUDA, &block_table_moves);
|
||||
|
||||
// Layernorm
|
||||
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
||||
ops.def(
|
||||
|
@ -202,6 +202,33 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
|
||||
block_table_bound)
|
||||
|
||||
|
||||
def block_table_appends(
|
||||
append_row_indices: torch.Tensor,
|
||||
append_row_indices_cpu: torch.Tensor,
|
||||
append_cumsums: torch.Tensor,
|
||||
append_cumsums_cpu: torch.Tensor,
|
||||
append_block_ids: torch.Tensor,
|
||||
append_block_ids_cpu: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
num_appends: int,
|
||||
total_num_append_blocks: int,
|
||||
) -> None:
|
||||
torch.ops._C.block_table_appends.default(
|
||||
append_row_indices, append_row_indices_cpu, append_cumsums,
|
||||
append_cumsums_cpu, append_block_ids, append_block_ids_cpu,
|
||||
block_table, num_appends, total_num_append_blocks)
|
||||
|
||||
|
||||
def block_table_moves(
|
||||
src_dst_n: torch.Tensor,
|
||||
src_dst_n_cpu: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
num_moves: int,
|
||||
) -> None:
|
||||
torch.ops._C.block_table_moves.default(src_dst_n, src_dst_n_cpu,
|
||||
block_table, num_moves)
|
||||
|
||||
|
||||
# fused quant layer norm ops
|
||||
def rms_norm_dynamic_per_token_quant(
|
||||
input: torch.Tensor,
|
||||
|
@ -9,6 +9,7 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BlockTable:
|
||||
"""Device-agnostic block table for storing block IDs for each request."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
164
vllm/v1/worker/gpu_block_table.py
Normal file
164
vllm/v1/worker/gpu_block_table.py
Normal file
@ -0,0 +1,164 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import List, Set
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GPUBlockTable:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_num_blocks_per_req: int,
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_blocks_per_req = max_num_blocks_per_req
|
||||
self.pin_memory = pin_memory
|
||||
self.device = device
|
||||
|
||||
self.block_table = torch.zeros(
|
||||
(max_num_reqs, max_num_blocks_per_req),
|
||||
device=self.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
self.block_table_cpu = torch.zeros(
|
||||
(max_num_reqs, max_num_blocks_per_req),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=False,
|
||||
)
|
||||
self.block_table_np = self.block_table_cpu.numpy()
|
||||
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
|
||||
self.block_table_diff_np = np.zeros(
|
||||
(max_num_reqs, 2),
|
||||
dtype=np.int32,
|
||||
)
|
||||
self.diff_rows: Set[int] = set()
|
||||
|
||||
self.append_row_indices = torch.zeros(
|
||||
(max_num_reqs, 2),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.append_row_indices_cpu = torch.zeros_like(
|
||||
self.append_row_indices,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.append_row_indices_np = self.append_row_indices_cpu.numpy()
|
||||
self.append_cumsums = torch.zeros(
|
||||
(max_num_reqs + 1, ),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.append_cumsums_cpu = torch.zeros_like(
|
||||
self.append_cumsums,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.append_cumsums_np = self.append_cumsums_cpu.numpy()
|
||||
self.append_data = torch.zeros(
|
||||
(max_num_reqs * max_num_blocks_per_req, ),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.append_data_cpu = torch.zeros_like(
|
||||
self.append_data,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.append_data_np = self.append_data_cpu.numpy()
|
||||
|
||||
def append_row(
|
||||
self,
|
||||
row_idx: int,
|
||||
start: int,
|
||||
block_ids: List[int],
|
||||
) -> None:
|
||||
num_blocks = len(block_ids)
|
||||
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
|
||||
self.num_blocks_per_row[row_idx] = start + num_blocks
|
||||
|
||||
self.block_table_diff_np[row_idx, 0] = start
|
||||
self.block_table_diff_np[row_idx, 1] = num_blocks
|
||||
self.diff_rows.add(row_idx)
|
||||
|
||||
def add_row(self, row_idx: int, block_ids: List[int]) -> None:
|
||||
self.append_row(row_idx, 0, block_ids)
|
||||
|
||||
def move_row(self, src: int, tgt: int) -> None:
|
||||
num_blocks = self.num_blocks_per_row[src]
|
||||
self.block_table_np[tgt, :num_blocks] = self.block_table_np[
|
||||
src, :num_blocks]
|
||||
self.num_blocks_per_row[tgt] = num_blocks
|
||||
|
||||
self.block_table_diff_np[tgt, 0] = 0
|
||||
self.block_table_diff_np[tgt, 1] = num_blocks
|
||||
self.diff_rows.discard(src)
|
||||
self.diff_rows.add(tgt)
|
||||
|
||||
def commit(self, num_reqs: int) -> None:
|
||||
if not self.diff_rows:
|
||||
return
|
||||
|
||||
cu_end = 0
|
||||
self.append_cumsums_np[0] = 0
|
||||
for i, row_idx in enumerate(self.diff_rows):
|
||||
start, num_blocks = self.block_table_diff_np[row_idx]
|
||||
assert num_blocks > 0
|
||||
|
||||
self.append_row_indices_np[i, 0] = row_idx
|
||||
self.append_row_indices_np[i, 1] = start
|
||||
cu_start = self.append_cumsums_np[i]
|
||||
cu_end = cu_start + num_blocks
|
||||
self.append_cumsums_np[i + 1] = cu_end
|
||||
self.append_data_np[cu_start:cu_end] = self.block_table_np[
|
||||
row_idx, start:start + num_blocks]
|
||||
|
||||
ops.block_table_appends(
|
||||
self.append_row_indices,
|
||||
self.append_row_indices_cpu,
|
||||
self.append_cumsums,
|
||||
self.append_cumsums_cpu,
|
||||
self.append_data,
|
||||
self.append_data_cpu,
|
||||
self.block_table,
|
||||
len(self.diff_rows),
|
||||
cu_end,
|
||||
)
|
||||
self.diff_rows.clear()
|
||||
|
||||
def clear(self) -> None:
|
||||
self.block_table.fill_(0)
|
||||
self.block_table_cpu.fill_(0)
|
||||
|
||||
self.diff_rows.clear()
|
||||
self.block_table_diff_np.fill(0)
|
||||
|
||||
self.append_row_indices.fill_(0)
|
||||
self.append_row_indices_cpu.fill_(0)
|
||||
self.append_cumsums.fill_(0)
|
||||
self.append_cumsums_cpu.fill_(0)
|
||||
self.append_data.fill_(0)
|
||||
self.append_data_cpu.fill_(0)
|
||||
|
||||
def get_device_tensor(self) -> torch.Tensor:
|
||||
"""Ruturns the device tensor of the block table."""
|
||||
return self.block_table
|
||||
|
||||
def get_cpu_tensor(self) -> torch.Tensor:
|
||||
"""Returns the CPU tensor of the block table."""
|
||||
return self.block_table_cpu
|
||||
|
||||
def get_numpy_array(self) -> np.ndarray:
|
||||
"""Returns the numpy array of the block table."""
|
||||
return self.block_table_np
|
@ -14,7 +14,7 @@ from vllm.utils import swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.utils import copy_slice
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
from vllm.v1.worker.gpu_block_table import GPUBlockTable
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
@ -92,7 +92,7 @@ class InputBatch:
|
||||
self.num_computed_tokens_cpu_tensor.numpy()
|
||||
|
||||
# Block table.
|
||||
self.block_table = BlockTable(
|
||||
self.block_table = GPUBlockTable(
|
||||
max_num_reqs=max_num_reqs,
|
||||
max_num_blocks_per_req=max_num_blocks_per_req,
|
||||
pin_memory=pin_memory,
|
||||
|
Reference in New Issue
Block a user