Compare commits

...

22 Commits

Author SHA1 Message Date
7097f31955 test
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-01-15 03:22:32 -08:00
f840b53063 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-01-15 03:07:17 -08:00
1ca4298b9b Fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-01-01 18:44:21 -08:00
ba64a0249f Minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-01-01 18:42:22 -08:00
1260e43230 Minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-01-01 03:16:56 -08:00
a6e5d7b5b7 Merge branch 'main' into v1-blocktable-opt 2025-01-01 03:10:50 -08:00
ebfbe1244b ruff
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-12-26 20:06:53 -08:00
6ba31aa5f6 Minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-12-26 19:03:59 -08:00
34d6cc2aea Merge branch 'main' into v1-blocktable-opt 2024-12-26 18:52:19 -08:00
27e8eb2e94 Add kernel test
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-12-26 11:23:52 -08:00
ca4f9e69a8 minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-12-26 11:13:41 -08:00
52922193cd Add test for uva
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-12-26 11:00:19 -08:00
bef68163a0 Minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-12-26 10:48:29 -08:00
ff5b1033dc Merge branch 'main' into v1-blocktable-opt 2024-12-26 10:12:17 -08:00
b938606993 Merge branch 'main' into v1-blocktable-opt 2024-12-25 15:49:02 -08:00
3fdbd8e2f5 comments
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-12-22 22:39:03 -08:00
0420fb2c7b Merge branch 'main' into v1-blocktable-opt 2024-12-22 22:16:22 -08:00
ee965c9c69 Use default
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-12-22 22:16:12 -08:00
0a669eed7b Minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-12-21 17:39:13 -08:00
03b1e6fdbd Minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-12-21 17:28:21 -08:00
8a4180c8b6 yapf
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-12-21 17:11:00 -08:00
1aaced5830 wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-12-21 17:07:46 -08:00
14 changed files with 477 additions and 24 deletions

View File

@ -193,6 +193,7 @@ set(VLLM_EXT_SRC
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/cuda_view.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
@ -200,6 +201,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/prepare_inputs/copy_subranges.cu"
"csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")

View File

@ -47,3 +47,11 @@
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif
// #ifndef USE_ROCM
// #define VLLM_cudaHostGetDevicePointer(device_ptr, host_ptr, flags) \
// cudaHostGetDevicePointer(device_ptr, host_ptr, flags)
// #else
// #define VLLM_cudaHostGetDevicePointer(device_ptr, host_ptr, flags) \
// hipHostGetDevicePointer(device_ptr, host_ptr, flags)
// #endif

43
csrc/cuda_view.cu Normal file
View File

@ -0,0 +1,43 @@
#include <torch/all.h>
#include <torch/cuda.h>
// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned
// memory, and that UVA (Unified Virtual Addressing) is enabled.
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
TORCH_CHECK(cpu_tensor.is_contiguous(), "Input tensor must be contiguous");
// Get raw host pointer from CPU tensor
void* host_ptr = cpu_tensor.data_ptr();
// Get a device pointer corresponding to the pinned host memory
void* device_ptr = nullptr;
cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
TORCH_CHECK(err == cudaSuccess,
"cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
// Construct a CUDA tensor from the device pointer.
// We'll use the same sizes, strides, and dtype as the CPU tensor.
auto sizes = cpu_tensor.sizes();
auto strides = cpu_tensor.strides();
auto options =
cpu_tensor.options().device(torch::kCUDA); // Change device to CUDA
// from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
// const TensorOptions &) Provide a no-op deleter. The CPU tensor holds the
// memory, so we don't free it here.
auto deleter = [](void*) {
// no-op, since the memory is owned by the original CPU tensor
};
torch::Tensor cuda_tensor =
torch::from_blob(device_ptr, sizes, strides, deleter, options);
TORCH_CHECK(cuda_tensor.device().is_cuda(),
"Resulting tensor is not on CUDA device");
TORCH_CHECK(cuda_tensor.sizes().equals(sizes), "Size mismatch");
TORCH_CHECK(cuda_tensor.strides().equals(strides), "Stride mismatch");
TORCH_CHECK(cuda_tensor.dtype() == cpu_tensor.dtype(), "Dtype mismatch");
return cuda_tensor;
}

View File

@ -115,6 +115,11 @@ void advance_step_flashinfer(
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
torch::Tensor& matrix_tgt, int64_t n);
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,

View File

@ -0,0 +1,75 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
namespace vllm {
__global__ void copy_subranges_kernel(const int* __restrict__ matrix_src,
const int* __restrict__ matrix_diff,
int* __restrict__ matrix_tgt, int64_t M) {
int row_id = blockIdx.x;
int row_offset = row_id * M;
int start = matrix_diff[row_id * 2];
int length = matrix_diff[row_id * 2 + 1];
int end = start + length;
int thread_idx = threadIdx.x;
for (int i = start + thread_idx; i < end; i += blockDim.x) {
int idx = row_offset + i;
matrix_tgt[idx] = matrix_src[idx];
}
}
} // namespace vllm
void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
torch::Tensor& matrix_tgt, int64_t n) {
// NOTE(woosuk): Here, we skip most of the error checking to minimize the
// CPU overheads. We assume that the caller will pass the correct inputs.
// Check tensor properties
// TORCH_CHECK(matrix_src.is_cuda(), "matrix_src must be a CUDA tensor");
// TORCH_CHECK(matrix_diff.is_cuda(), "matrix_diff must be a CUDA tensor");
// TORCH_CHECK(matrix_tgt.is_cuda(), "matrix_tgt must be a CUDA tensor");
// TORCH_CHECK(matrix_src.is_contiguous(), "matrix_src must be contiguous");
// TORCH_CHECK(matrix_diff.is_contiguous(), "matrix_diff must be contiguous");
// TORCH_CHECK(matrix_tgt.is_contiguous(), "matrix_tgt must be contiguous");
auto src_sizes = matrix_src.sizes();
auto diff_sizes = matrix_diff.sizes();
auto tgt_sizes = matrix_tgt.sizes();
// TORCH_CHECK(src_sizes.size() == 2, "matrix_src must be 2D");
// TORCH_CHECK(diff_sizes.size() == 2, "matrix_diff must be 2D");
// TORCH_CHECK(tgt_sizes.size() == 2, "matrix_tgt must be 2D");
int64_t N = src_sizes[0];
int64_t M = src_sizes[1];
// TORCH_CHECK(diff_sizes[0] == N, "matrix_diff first dim must match N");
// TORCH_CHECK(diff_sizes[1] == 2, "matrix_diff second dim must be 2");
// TORCH_CHECK(tgt_sizes[0] == N && tgt_sizes[1] == M,
// "matrix_tgt must have same shape as matrix_src");
// TORCH_CHECK(n <= N, "n must be <= N");
const int* d_matrix_src = matrix_src.data_ptr<int>();
const int* d_matrix_diff = matrix_diff.data_ptr<int>();
int* d_matrix_tgt = matrix_tgt.data_ptr<int>();
// One thread block per row.
int blocks = n;
int threads;
if (blocks < 128) {
threads = 1024;
} else if (blocks < 256) {
threads = 512;
} else if (blocks < 512) {
threads = 256;
} else {
threads = 128;
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(matrix_tgt));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::copy_subranges_kernel<<<blocks, threads, 0, stream>>>(
d_matrix_src, d_matrix_diff, d_matrix_tgt, M);
}

View File

@ -21,6 +21,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
&get_cuda_view_from_cpu_tensor);
// Attention ops
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
@ -98,6 +102,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
") -> ()");
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
ops.def(
"copy_subranges(Tensor matrix_src, Tensor matrix_diff, Tensor! "
"matrix_tgt, int n) -> ()");
ops.impl("copy_subranges", torch::kCUDA, &copy_subranges);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(

View File

@ -0,0 +1,47 @@
import random
import pytest
import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_copy_subranges(seed, device):
torch.set_default_device(device)
current_platform.seed_everything(seed)
num_rows = 1024
num_cols = 1024
src_matrix = torch.zeros(num_rows,
num_cols,
device=device,
dtype=torch.int32)
dst_matrix = torch.zeros(num_rows,
num_cols,
device=device,
dtype=torch.int32)
diff_matrix = torch.zeros(num_rows, 2, device=device, dtype=torch.int32)
for i in range(num_rows):
start_idx = random.randint(0, num_cols - 1)
end_idx = random.randint(start_idx, num_cols - 1)
num_diffs = end_idx - start_idx
src_matrix[i, start_idx:end_idx] = torch.randint(0,
100, (num_diffs, ),
device=device,
dtype=torch.int32)
diff_matrix[i, 0] = start_idx
diff_matrix[i, 1] = num_diffs
ops.copy_subranges(src_matrix, diff_matrix, dst_matrix, num_rows)
assert torch.allclose(src_matrix, dst_matrix, rtol=0, atol=0)

60
tests/kernels/test_uva.py Normal file
View File

@ -0,0 +1,60 @@
import pytest
import torch
from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cpu_write(device):
torch.set_default_device(device)
cpu_tensor = torch.zeros(10,
10,
device="cpu",
pin_memory=True,
dtype=torch.int32)
cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
assert cuda_view.device.type == "cuda"
assert cuda_view[0, 0] == 0
assert cuda_view[2, 3] == 0
assert cuda_view[4, 5] == 0
cpu_tensor[0, 0] = 1
cpu_tensor[2, 3] = 2
cpu_tensor[4, 5] = -1
cuda_view.mul_(2)
assert cuda_view[0, 0] == 2
assert cuda_view[2, 3] == 4
assert cuda_view[4, 5] == -2
@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_gpu_write(device):
torch.set_default_device(device)
cpu_tensor = torch.zeros(10,
10,
device="cpu",
pin_memory=True,
dtype=torch.int32)
cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
assert cuda_view.device.type == "cuda"
assert cuda_view[0, 0] == 0
assert cuda_view[2, 3] == 0
assert cuda_view[4, 5] == 0
cuda_view[0, 0] = 1
cuda_view[2, 3] = 2
cuda_view[4, 5] = -1
cuda_view.mul_(2)
assert cpu_tensor[0, 0] == 2
assert cpu_tensor[2, 3] == 4
assert cpu_tensor[4, 5] == -2

View File

@ -0,0 +1,52 @@
import pytest
import random
import time
import torch
from vllm.v1.worker.gpu_block_table import BlockTable
MAX_NUM_REQS = 1024
MAX_MODEL_LEN = 128 * 1024
BLOCK_SIZE = 16
MAX_NUM_BLOCKS_PER_REQ = MAX_MODEL_LEN // BLOCK_SIZE
def test_block_table(do_wait: bool):
random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
block_table = BlockTable(
max_num_reqs=MAX_NUM_REQS,
max_model_len=MAX_MODEL_LEN,
max_num_blocks_per_req=MAX_NUM_BLOCKS_PER_REQ,
pin_memory=True,
device=torch.device(0),
)
num_blocks = random.randint(1, MAX_NUM_BLOCKS_PER_REQ - 1)
block_ids = torch.randint(0, MAX_NUM_BLOCKS_PER_REQ, (num_blocks,), dtype=torch.int32, device="cpu")
block_table.add_row(0, block_ids)
num_blocks = random.randint(1, MAX_NUM_BLOCKS_PER_REQ - 100)
block_ids = torch.randint(0, MAX_NUM_BLOCKS_PER_REQ, (num_blocks,), dtype=torch.int32, device="cpu")
block_table.add_row(1, block_ids)
block_table.commit(2)
torch.cuda.synchronize()
if do_wait:
time.sleep(1)
block_ids = torch.randint(0, MAX_NUM_BLOCKS_PER_REQ, (100,), dtype=torch.int32, device="cpu")
block_table.append_row(1, num_blocks, block_ids)
block_table.move_row(1, 0)
block_table.commit(2)
torch.cuda.synchronize()
if do_wait:
time.sleep(1)
torch.testing.assert_close(block_table.block_table[:1].cpu(), block_table.block_table_cpu[:1])
if __name__ == "__main__":
test_block_table(do_wait=False)

View File

@ -219,6 +219,20 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
block_table_bound)
# copy subrange op. Used for input preparation in the vLLM V1 GPU backend.
def copy_subranges(
src_matrix: torch.Tensor,
diff_matrix: torch.Tensor,
tgt_matrix: torch.Tensor,
num_subranges: int,
) -> None:
# NOTE(woosuk): We use `torch.ops._C.copy_subranges.default` instead of
# `torch.ops._C.copy_subranges` to avoid unnecessary CPU overheads from
# the dispatcher.
torch.ops._C.copy_subranges.default(src_matrix, diff_matrix, tgt_matrix,
num_subranges)
# fused quant layer norm ops
def rms_norm_dynamic_per_token_quant(
input: torch.Tensor,

View File

@ -707,6 +707,14 @@ def is_pin_memory_available() -> bool:
return current_platform.is_pin_memory_available()
@lru_cache(maxsize=None)
def is_uva_available() -> bool:
"""Check if Unified Virtual Addressing (UVA) is available."""
# UVA requires pinned memory.
# TODO(woosuk): Add more requirements for UVA.
return is_pin_memory_available()
class DeviceMemoryProfiler:
def __init__(self, device: Optional[torch.types.Device] = None):
@ -1557,6 +1565,14 @@ def weak_ref_tensors(
raise ValueError("Invalid type for tensors")
def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
"""
Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).
"""
assert cpu_tensor.is_pinned(), "CPU tensor must be pinned"
return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
def is_in_doc_build() -> bool:
try:
from sphinx.ext.autodoc.mock import _MockModule

View File

@ -0,0 +1,131 @@
from typing import List
import numpy as np
import torch
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available
logger = init_logger(__name__)
class BlockTable:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_blocks_per_req: int,
pin_memory: bool,
device: torch.device,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_blocks_per_req = max_num_blocks_per_req
self.pin_memory = pin_memory
self.device = device
self.block_table = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device=self.device,
dtype=torch.int32,
)
self.block_table_cpu = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.block_table_np = self.block_table_cpu.numpy()
self.num_blocks_per_row = np.zeros((max_num_reqs,), dtype=np.int32)
# UVA requires pinned memory.
self.use_uva = is_uva_available() and pin_memory
if self.use_uva:
logger.info("Using Unified Virtual Addressing (UVA) for block "
"table transfer.")
self.block_table_diff = torch.zeros((max_num_reqs, 2),
dtype=torch.int32,
device="cpu",
pin_memory=True)
self.block_table_diff_np = self.block_table_diff.numpy()
self.block_table_cpu_cuda_view = get_cuda_view_from_cpu_tensor(
self.block_table_cpu)
self.block_table_diff_cuda_view = get_cuda_view_from_cpu_tensor(
self.block_table_diff)
else:
logger.warning("Unified Virtual Addressing (UVA) is not supported "
"in the current environment. This may result in "
"lower performance.")
def add_row(self, row_idx: int, block_ids: List[int]) -> None:
num_blocks = len(block_ids)
self.block_table_np[row_idx, :num_blocks] = block_ids
self.num_blocks_per_row[row_idx] = num_blocks
if self.use_uva:
self.block_table_diff_np[row_idx, 0] = 0
self.block_table_diff_np[row_idx, 1] = num_blocks
def append_row(
self,
row_idx: int,
start: int,
block_ids: List[int],
) -> None:
num_blocks = len(block_ids)
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
self.num_blocks_per_row[row_idx] = start + num_blocks
if self.use_uva:
self.block_table_diff_np[row_idx, 0] = start
self.block_table_diff_np[row_idx, 1] = num_blocks
def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src]
self.block_table_np[tgt, :num_blocks] = \
self.block_table_np[src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks
if self.use_uva:
# Append-and-move is allowed.
self.block_table_diff_np[tgt, 0] = 0
self.block_table_diff_np[tgt, 1] = num_blocks
# Clear the source row.
self.block_table_diff_np[src].fill(0)
def commit(self, num_reqs: int) -> None:
if self.use_uva:
# Only copy the diff to the GPU.
ops.copy_subranges(
self.block_table_cpu_cuda_view,
self.block_table_diff_cuda_view,
self.block_table,
num_reqs,
)
else:
# Copy the entire block table to the GPU.
# NOTE(woosuk): This can be a performance bottleneck when the block
# table is large.
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
non_blocking=True)
self.clear_diff()
def clear(self) -> None:
self.block_table.fill_(0)
self.block_table_cpu.fill_(0)
self.num_blocks_per_row.fill(0)
if self.use_uva:
self.block_table_diff.fill_(0)
def clear_diff(self) -> None:
if self.use_uva:
self.block_table_diff_np.fill(0)
def cuda(self) -> torch.Tensor:
return self.block_table
def cpu(self) -> torch.Tensor:
return self.block_table_cpu
def numpy(self) -> np.ndarray:
return self.block_table_np

View File

@ -9,6 +9,7 @@ import torch
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_block_table import BlockTable
if TYPE_CHECKING:
from vllm.multimodal.inputs import PlaceholderRange
@ -69,19 +70,14 @@ class InputBatch:
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
# Attention-related.
self.block_table = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device=self.device,
dtype=torch.int32,
)
self.block_table_cpu_tensor = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device="cpu",
dtype=torch.int32,
# Block table.
self.block_table = BlockTable(
max_num_reqs=max_num_reqs,
max_model_len=max_model_len,
max_num_blocks_per_req=max_num_blocks_per_req,
pin_memory=pin_memory,
device=device,
)
self.block_table_cpu = self.block_table_cpu_tensor.numpy()
# Sampling-related.
self.temperature = torch.empty((max_num_reqs, ),
@ -191,8 +187,7 @@ class InputBatch:
start_idx:end_idx] = request.output_token_ids
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
num_blocks = len(request.block_ids)
self.block_table_cpu[req_index, :num_blocks] = request.block_ids
self.block_table.add_row(req_index, request.block_ids)
sampling_params = request.sampling_params
self.temperature_cpu[req_index] = sampling_params.temperature
@ -291,15 +286,14 @@ class InputBatch:
self.req_id_to_index[req_id] = empty_index
# TODO(woosuk): Optimize the copy of token_ids_cpu and
# block_table_cpu.
# block_table.
self.token_ids_cpu[empty_index] = self.token_ids_cpu[
last_req_index]
self.num_prompt_tokens[empty_index] = \
self.num_prompt_tokens[last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.block_table_cpu[empty_index] = self.block_table_cpu[
last_req_index]
self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]

View File

@ -204,10 +204,9 @@ class GPUModelRunner:
if num_new_blocks == 0:
continue
start_index = len(req_state.block_ids)
end_index = start_index + num_new_blocks
req_state.block_ids.extend(req_data.new_block_ids)
self.input_batch.block_table_cpu[
req_index, start_index:end_index] = req_data.new_block_ids
self.input_batch.block_table.append_row(req_index, start_index,
req_data.new_block_ids)
req_ids_to_add: List[str] = []
# Add new requests to the cached states.
@ -268,9 +267,7 @@ class GPUModelRunner:
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table[:num_reqs].copy_(
self.input_batch.block_table_cpu_tensor[:num_reqs],
non_blocking=True)
self.input_batch.block_table.commit(num_reqs)
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
@ -326,7 +323,7 @@ class GPUModelRunner:
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
block_numbers = (self.input_batch.block_table_cpu_tensor.flatten()
block_numbers = (self.input_batch.block_table.cpu().flatten()
[block_table_indices].numpy())
block_offsets = positions_np % self.block_size
np.add(block_numbers * self.block_size,
@ -361,7 +358,7 @@ class GPUModelRunner:
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_start_loc=seq_start_loc,
block_table=self.input_batch.block_table[:num_reqs],
block_table=self.input_batch.block_table.cuda()[:num_reqs],
slot_mapping=slot_mapping,
)
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial