mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-11-04 17:34:34 +08:00 
			
		
		
		
	Compare commits
	
		
			4 Commits
		
	
	
		
			v0.11.0rc2
			...
			v1-block-t
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 44d638a896 | |||
| caacd1ddfb | |||
| e68f63ef83 | |||
| 223e17424c | 
@ -228,6 +228,7 @@ endif()
 | 
			
		||||
 | 
			
		||||
set(VLLM_EXT_SRC
 | 
			
		||||
  "csrc/cache_kernels.cu"
 | 
			
		||||
  "csrc/block_table.cu"
 | 
			
		||||
  "csrc/attention/paged_attention_v1.cu"
 | 
			
		||||
  "csrc/attention/paged_attention_v2.cu"
 | 
			
		||||
  "csrc/pos_encoding_kernels.cu"
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										92
									
								
								csrc/block_table.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										92
									
								
								csrc/block_table.cu
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,92 @@
 | 
			
		||||
#include <torch/all.h>
 | 
			
		||||
 | 
			
		||||
#include <ATen/cuda/CUDAContext.h>
 | 
			
		||||
#include <c10/cuda/CUDAGuard.h>
 | 
			
		||||
 | 
			
		||||
namespace vllm {
 | 
			
		||||
__global__ void append_kernel(const int* __restrict__ row_indices,
 | 
			
		||||
                              const int* __restrict__ cu_num_appends,
 | 
			
		||||
                              const int* __restrict__ block_ids,
 | 
			
		||||
                              int* __restrict__ block_table,
 | 
			
		||||
                              int max_num_blocks_per_row) {
 | 
			
		||||
  int bid = blockIdx.x;
 | 
			
		||||
  int tgt_row = row_indices[2 * bid];
 | 
			
		||||
  int tgt_offset = row_indices[2 * bid + 1];
 | 
			
		||||
 | 
			
		||||
  int start = cu_num_appends[bid];
 | 
			
		||||
  int end = cu_num_appends[bid + 1];
 | 
			
		||||
  int length = end - start;
 | 
			
		||||
  int tid = threadIdx.x;
 | 
			
		||||
  int64_t offset = tgt_row * max_num_blocks_per_row + tgt_offset;
 | 
			
		||||
  for (int i = tid; i < length; i += blockDim.x) {
 | 
			
		||||
    block_table[offset + i] = block_ids[start + i];
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
__global__ void move_kernel(const int* __restrict__ src_dst_n,
 | 
			
		||||
                            int* __restrict__ block_table,
 | 
			
		||||
                            int max_num_blocks_per_row) {
 | 
			
		||||
  int bid = blockIdx.x;
 | 
			
		||||
  int src_row = src_dst_n[3 * bid];
 | 
			
		||||
  int tgt_row = src_dst_n[3 * bid + 1];
 | 
			
		||||
  int num_blocks = src_dst_n[3 * bid + 2];
 | 
			
		||||
 | 
			
		||||
  int tid = threadIdx.x;
 | 
			
		||||
  for (int i = tid; i < num_blocks; i += blockDim.x) {
 | 
			
		||||
    block_table[tgt_row * max_num_blocks_per_row + i] =
 | 
			
		||||
        block_table[src_row * max_num_blocks_per_row + i];
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
}  // namespace vllm
 | 
			
		||||
 | 
			
		||||
void block_table_appends(
 | 
			
		||||
    torch::Tensor& append_row_indices,
 | 
			
		||||
    torch::Tensor& append_row_indices_cpu,
 | 
			
		||||
    torch::Tensor& append_cumsums,
 | 
			
		||||
    torch::Tensor& append_cumsums_cpu,
 | 
			
		||||
    torch::Tensor& append_block_ids,
 | 
			
		||||
    torch::Tensor& append_block_ids_cpu,
 | 
			
		||||
    torch::Tensor& block_table,
 | 
			
		||||
    int64_t num_appends,
 | 
			
		||||
    int64_t total_num_append_blocks) {
 | 
			
		||||
  int* append_row_indices_ptr = append_row_indices.data_ptr<int>();
 | 
			
		||||
  const int* append_row_indices_cpu_ptr = append_row_indices_cpu.data_ptr<int>();
 | 
			
		||||
  int* append_cumsums_ptr = append_cumsums.data_ptr<int>();
 | 
			
		||||
  const int* append_cumsums_cpu_ptr = append_cumsums_cpu.data_ptr<int>();
 | 
			
		||||
  int* append_block_ids_ptr = append_block_ids.data_ptr<int>();
 | 
			
		||||
  const int* append_block_ids_cpu_ptr = append_block_ids_cpu.data_ptr<int>();
 | 
			
		||||
  int* block_table_ptr = block_table.data_ptr<int>();
 | 
			
		||||
 | 
			
		||||
  const at::cuda::OptionalCUDAGuard device_guard(device_of(block_table));
 | 
			
		||||
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
  cudaMemcpyAsync(append_row_indices_ptr, append_row_indices_cpu_ptr,
 | 
			
		||||
                  num_appends * 2 * sizeof(int), cudaMemcpyHostToDevice, stream);
 | 
			
		||||
  cudaMemcpyAsync(append_cumsums_ptr, append_cumsums_cpu_ptr,
 | 
			
		||||
                  (num_appends + 1) * sizeof(int), cudaMemcpyHostToDevice, stream);
 | 
			
		||||
  cudaMemcpyAsync(append_block_ids_ptr, append_block_ids_cpu_ptr,
 | 
			
		||||
                  total_num_append_blocks * sizeof(int), cudaMemcpyHostToDevice, stream);
 | 
			
		||||
 | 
			
		||||
  int64_t max_num_blocks_per_row = block_table.size(1);
 | 
			
		||||
  vllm::append_kernel<<<num_appends, 1024, 0, stream>>>(
 | 
			
		||||
      append_row_indices_ptr, append_cumsums_ptr, append_block_ids_ptr,
 | 
			
		||||
      block_table_ptr, max_num_blocks_per_row);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void block_table_moves(
 | 
			
		||||
    torch::Tensor& src_dst_n,
 | 
			
		||||
    torch::Tensor& src_dst_n_cpu,
 | 
			
		||||
    torch::Tensor& block_table,
 | 
			
		||||
    int64_t num_moves) {
 | 
			
		||||
  int* src_dst_n_ptr = src_dst_n.data_ptr<int>();
 | 
			
		||||
  const int* src_dst_n_cpu_ptr = src_dst_n_cpu.data_ptr<int>();
 | 
			
		||||
  int* block_table_ptr = block_table.data_ptr<int>();
 | 
			
		||||
 | 
			
		||||
  const at::cuda::OptionalCUDAGuard device_guard(device_of(block_table));
 | 
			
		||||
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
  cudaMemcpyAsync(src_dst_n_ptr, src_dst_n_cpu_ptr,
 | 
			
		||||
                  num_moves * 3 * sizeof(int), cudaMemcpyHostToDevice, stream);
 | 
			
		||||
 | 
			
		||||
  int64_t max_num_blocks_per_row = block_table.size(1);
 | 
			
		||||
  vllm::move_kernel<<<num_moves, 1024, 0, stream>>>(
 | 
			
		||||
      src_dst_n_ptr, block_table_ptr, max_num_blocks_per_row);
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										12
									
								
								csrc/ops.h
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								csrc/ops.h
									
									
									
									
									
								
							@ -119,6 +119,18 @@ void advance_step_flashinfer(
 | 
			
		||||
    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
 | 
			
		||||
    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
 | 
			
		||||
 | 
			
		||||
void block_table_appends(torch::Tensor& append_row_indices,
 | 
			
		||||
                         torch::Tensor& append_row_indices_cpu,
 | 
			
		||||
                         torch::Tensor& append_cumsums,
 | 
			
		||||
                         torch::Tensor& append_cumsums_cpu,
 | 
			
		||||
                         torch::Tensor& append_block_ids,
 | 
			
		||||
                         torch::Tensor& append_block_ids_cpu,
 | 
			
		||||
                         torch::Tensor& block_table, int64_t num_appends,
 | 
			
		||||
                         int64_t total_num_append_blocks);
 | 
			
		||||
 | 
			
		||||
void block_table_moves(torch::Tensor& src_dst_n, torch::Tensor& src_dst_n_cpu,
 | 
			
		||||
                       torch::Tensor& block_table, int64_t num_moves);
 | 
			
		||||
 | 
			
		||||
#ifndef USE_ROCM
 | 
			
		||||
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
 | 
			
		||||
                        const torch::Tensor& codebooks,
 | 
			
		||||
 | 
			
		||||
@ -111,6 +111,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 | 
			
		||||
      ") -> ()");
 | 
			
		||||
  ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
 | 
			
		||||
 | 
			
		||||
  ops.def(
 | 
			
		||||
      "block_table_appends(Tensor append_row_indices, "
 | 
			
		||||
      "Tensor append_row_indices_cpu, Tensor append_cumsums, "
 | 
			
		||||
      "Tensor append_cumsums_cpu, Tensor append_block_ids, "
 | 
			
		||||
      "Tensor append_block_ids_cpu, Tensor! block_table, int num_appends, "
 | 
			
		||||
      "int total_num_append_blocks) -> ()");
 | 
			
		||||
  ops.impl("block_table_appends", torch::kCUDA, &block_table_appends);
 | 
			
		||||
 | 
			
		||||
  ops.def(
 | 
			
		||||
      "block_table_moves(Tensor src_dst_n, Tensor src_dst_n_cpu, "
 | 
			
		||||
      "Tensor! block_table, int num_moves) -> ()");
 | 
			
		||||
  ops.impl("block_table_moves", torch::kCUDA, &block_table_moves);
 | 
			
		||||
 | 
			
		||||
  // Layernorm
 | 
			
		||||
  // Apply Root Mean Square (RMS) Normalization to the input tensor.
 | 
			
		||||
  ops.def(
 | 
			
		||||
 | 
			
		||||
@ -202,6 +202,33 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
 | 
			
		||||
        block_table_bound)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def block_table_appends(
 | 
			
		||||
    append_row_indices: torch.Tensor,
 | 
			
		||||
    append_row_indices_cpu: torch.Tensor,
 | 
			
		||||
    append_cumsums: torch.Tensor,
 | 
			
		||||
    append_cumsums_cpu: torch.Tensor,
 | 
			
		||||
    append_block_ids: torch.Tensor,
 | 
			
		||||
    append_block_ids_cpu: torch.Tensor,
 | 
			
		||||
    block_table: torch.Tensor,
 | 
			
		||||
    num_appends: int,
 | 
			
		||||
    total_num_append_blocks: int,
 | 
			
		||||
) -> None:
 | 
			
		||||
    torch.ops._C.block_table_appends.default(
 | 
			
		||||
        append_row_indices, append_row_indices_cpu, append_cumsums,
 | 
			
		||||
        append_cumsums_cpu, append_block_ids, append_block_ids_cpu,
 | 
			
		||||
        block_table, num_appends, total_num_append_blocks)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def block_table_moves(
 | 
			
		||||
    src_dst_n: torch.Tensor,
 | 
			
		||||
    src_dst_n_cpu: torch.Tensor,
 | 
			
		||||
    block_table: torch.Tensor,
 | 
			
		||||
    num_moves: int,
 | 
			
		||||
) -> None:
 | 
			
		||||
    torch.ops._C.block_table_moves.default(src_dst_n, src_dst_n_cpu,
 | 
			
		||||
                                           block_table, num_moves)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# fused quant layer norm ops
 | 
			
		||||
def rms_norm_dynamic_per_token_quant(
 | 
			
		||||
    input: torch.Tensor,
 | 
			
		||||
 | 
			
		||||
@ -9,6 +9,7 @@ logger = init_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BlockTable:
 | 
			
		||||
    """Device-agnostic block table for storing block IDs for each request."""
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										164
									
								
								vllm/v1/worker/gpu_block_table.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										164
									
								
								vllm/v1/worker/gpu_block_table.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,164 @@
 | 
			
		||||
# SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
from typing import List, Set
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from vllm import _custom_ops as ops
 | 
			
		||||
from vllm.logger import init_logger
 | 
			
		||||
 | 
			
		||||
logger = init_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GPUBlockTable:
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        max_num_reqs: int,
 | 
			
		||||
        max_num_blocks_per_req: int,
 | 
			
		||||
        pin_memory: bool,
 | 
			
		||||
        device: torch.device,
 | 
			
		||||
    ):
 | 
			
		||||
        self.max_num_reqs = max_num_reqs
 | 
			
		||||
        self.max_num_blocks_per_req = max_num_blocks_per_req
 | 
			
		||||
        self.pin_memory = pin_memory
 | 
			
		||||
        self.device = device
 | 
			
		||||
 | 
			
		||||
        self.block_table = torch.zeros(
 | 
			
		||||
            (max_num_reqs, max_num_blocks_per_req),
 | 
			
		||||
            device=self.device,
 | 
			
		||||
            dtype=torch.int32,
 | 
			
		||||
        )
 | 
			
		||||
        self.block_table_cpu = torch.zeros(
 | 
			
		||||
            (max_num_reqs, max_num_blocks_per_req),
 | 
			
		||||
            device="cpu",
 | 
			
		||||
            dtype=torch.int32,
 | 
			
		||||
            pin_memory=False,
 | 
			
		||||
        )
 | 
			
		||||
        self.block_table_np = self.block_table_cpu.numpy()
 | 
			
		||||
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
 | 
			
		||||
 | 
			
		||||
        self.block_table_diff_np = np.zeros(
 | 
			
		||||
            (max_num_reqs, 2),
 | 
			
		||||
            dtype=np.int32,
 | 
			
		||||
        )
 | 
			
		||||
        self.diff_rows: Set[int] = set()
 | 
			
		||||
 | 
			
		||||
        self.append_row_indices = torch.zeros(
 | 
			
		||||
            (max_num_reqs, 2),
 | 
			
		||||
            dtype=torch.int32,
 | 
			
		||||
            device=self.device,
 | 
			
		||||
        )
 | 
			
		||||
        self.append_row_indices_cpu = torch.zeros_like(
 | 
			
		||||
            self.append_row_indices,
 | 
			
		||||
            device="cpu",
 | 
			
		||||
            pin_memory=pin_memory,
 | 
			
		||||
        )
 | 
			
		||||
        self.append_row_indices_np = self.append_row_indices_cpu.numpy()
 | 
			
		||||
        self.append_cumsums = torch.zeros(
 | 
			
		||||
            (max_num_reqs + 1, ),
 | 
			
		||||
            dtype=torch.int32,
 | 
			
		||||
            device=self.device,
 | 
			
		||||
        )
 | 
			
		||||
        self.append_cumsums_cpu = torch.zeros_like(
 | 
			
		||||
            self.append_cumsums,
 | 
			
		||||
            device="cpu",
 | 
			
		||||
            pin_memory=pin_memory,
 | 
			
		||||
        )
 | 
			
		||||
        self.append_cumsums_np = self.append_cumsums_cpu.numpy()
 | 
			
		||||
        self.append_data = torch.zeros(
 | 
			
		||||
            (max_num_reqs * max_num_blocks_per_req, ),
 | 
			
		||||
            dtype=torch.int32,
 | 
			
		||||
            device=self.device,
 | 
			
		||||
        )
 | 
			
		||||
        self.append_data_cpu = torch.zeros_like(
 | 
			
		||||
            self.append_data,
 | 
			
		||||
            device="cpu",
 | 
			
		||||
            pin_memory=pin_memory,
 | 
			
		||||
        )
 | 
			
		||||
        self.append_data_np = self.append_data_cpu.numpy()
 | 
			
		||||
 | 
			
		||||
    def append_row(
 | 
			
		||||
        self,
 | 
			
		||||
        row_idx: int,
 | 
			
		||||
        start: int,
 | 
			
		||||
        block_ids: List[int],
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        num_blocks = len(block_ids)
 | 
			
		||||
        self.block_table_np[row_idx, start:start + num_blocks] = block_ids
 | 
			
		||||
        self.num_blocks_per_row[row_idx] = start + num_blocks
 | 
			
		||||
 | 
			
		||||
        self.block_table_diff_np[row_idx, 0] = start
 | 
			
		||||
        self.block_table_diff_np[row_idx, 1] = num_blocks
 | 
			
		||||
        self.diff_rows.add(row_idx)
 | 
			
		||||
 | 
			
		||||
    def add_row(self, row_idx: int, block_ids: List[int]) -> None:
 | 
			
		||||
        self.append_row(row_idx, 0, block_ids)
 | 
			
		||||
 | 
			
		||||
    def move_row(self, src: int, tgt: int) -> None:
 | 
			
		||||
        num_blocks = self.num_blocks_per_row[src]
 | 
			
		||||
        self.block_table_np[tgt, :num_blocks] = self.block_table_np[
 | 
			
		||||
            src, :num_blocks]
 | 
			
		||||
        self.num_blocks_per_row[tgt] = num_blocks
 | 
			
		||||
 | 
			
		||||
        self.block_table_diff_np[tgt, 0] = 0
 | 
			
		||||
        self.block_table_diff_np[tgt, 1] = num_blocks
 | 
			
		||||
        self.diff_rows.discard(src)
 | 
			
		||||
        self.diff_rows.add(tgt)
 | 
			
		||||
 | 
			
		||||
    def commit(self, num_reqs: int) -> None:
 | 
			
		||||
        if not self.diff_rows:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        cu_end = 0
 | 
			
		||||
        self.append_cumsums_np[0] = 0
 | 
			
		||||
        for i, row_idx in enumerate(self.diff_rows):
 | 
			
		||||
            start, num_blocks = self.block_table_diff_np[row_idx]
 | 
			
		||||
            assert num_blocks > 0
 | 
			
		||||
 | 
			
		||||
            self.append_row_indices_np[i, 0] = row_idx
 | 
			
		||||
            self.append_row_indices_np[i, 1] = start
 | 
			
		||||
            cu_start = self.append_cumsums_np[i]
 | 
			
		||||
            cu_end = cu_start + num_blocks
 | 
			
		||||
            self.append_cumsums_np[i + 1] = cu_end
 | 
			
		||||
            self.append_data_np[cu_start:cu_end] = self.block_table_np[
 | 
			
		||||
                row_idx, start:start + num_blocks]
 | 
			
		||||
 | 
			
		||||
        ops.block_table_appends(
 | 
			
		||||
            self.append_row_indices,
 | 
			
		||||
            self.append_row_indices_cpu,
 | 
			
		||||
            self.append_cumsums,
 | 
			
		||||
            self.append_cumsums_cpu,
 | 
			
		||||
            self.append_data,
 | 
			
		||||
            self.append_data_cpu,
 | 
			
		||||
            self.block_table,
 | 
			
		||||
            len(self.diff_rows),
 | 
			
		||||
            cu_end,
 | 
			
		||||
        )
 | 
			
		||||
        self.diff_rows.clear()
 | 
			
		||||
 | 
			
		||||
    def clear(self) -> None:
 | 
			
		||||
        self.block_table.fill_(0)
 | 
			
		||||
        self.block_table_cpu.fill_(0)
 | 
			
		||||
 | 
			
		||||
        self.diff_rows.clear()
 | 
			
		||||
        self.block_table_diff_np.fill(0)
 | 
			
		||||
 | 
			
		||||
        self.append_row_indices.fill_(0)
 | 
			
		||||
        self.append_row_indices_cpu.fill_(0)
 | 
			
		||||
        self.append_cumsums.fill_(0)
 | 
			
		||||
        self.append_cumsums_cpu.fill_(0)
 | 
			
		||||
        self.append_data.fill_(0)
 | 
			
		||||
        self.append_data_cpu.fill_(0)
 | 
			
		||||
 | 
			
		||||
    def get_device_tensor(self) -> torch.Tensor:
 | 
			
		||||
        """Ruturns the device tensor of the block table."""
 | 
			
		||||
        return self.block_table
 | 
			
		||||
 | 
			
		||||
    def get_cpu_tensor(self) -> torch.Tensor:
 | 
			
		||||
        """Returns the CPU tensor of the block table."""
 | 
			
		||||
        return self.block_table_cpu
 | 
			
		||||
 | 
			
		||||
    def get_numpy_array(self) -> np.ndarray:
 | 
			
		||||
        """Returns the numpy array of the block table."""
 | 
			
		||||
        return self.block_table_np
 | 
			
		||||
@ -14,7 +14,7 @@ from vllm.utils import swap_dict_values
 | 
			
		||||
from vllm.v1.outputs import LogprobsTensors
 | 
			
		||||
from vllm.v1.sample.metadata import SamplingMetadata
 | 
			
		||||
from vllm.v1.utils import copy_slice
 | 
			
		||||
from vllm.v1.worker.block_table import BlockTable
 | 
			
		||||
from vllm.v1.worker.gpu_block_table import GPUBlockTable
 | 
			
		||||
 | 
			
		||||
_SAMPLING_EPS = 1e-5
 | 
			
		||||
 | 
			
		||||
@ -92,7 +92,7 @@ class InputBatch:
 | 
			
		||||
            self.num_computed_tokens_cpu_tensor.numpy()
 | 
			
		||||
 | 
			
		||||
        # Block table.
 | 
			
		||||
        self.block_table = BlockTable(
 | 
			
		||||
        self.block_table = GPUBlockTable(
 | 
			
		||||
            max_num_reqs=max_num_reqs,
 | 
			
		||||
            max_num_blocks_per_req=max_num_blocks_per_req,
 | 
			
		||||
            pin_memory=pin_memory,
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user