Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-21 07:13:52 +08:00)
Compare commits
22 commits: v0.7.2 ... v1-blockta
Commits (SHA1):

7097f31955
f840b53063
1ca4298b9b
ba64a0249f
1260e43230
a6e5d7b5b7
ebfbe1244b
6ba31aa5f6
34d6cc2aea
27e8eb2e94
ca4f9e69a8
52922193cd
bef68163a0
ff5b1033dc
b938606993
3fdbd8e2f5
0420fb2c7b
ee965c9c69
0a669eed7b
03b1e6fdbd
8a4180c8b6
1aaced5830
@@ -193,6 +193,7 @@ set(VLLM_EXT_SRC
   "csrc/activation_kernels.cu"
   "csrc/layernorm_kernels.cu"
   "csrc/layernorm_quant_kernels.cu"
+  "csrc/cuda_view.cu"
   "csrc/quantization/gptq/q_gemm.cu"
   "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
   "csrc/quantization/fp8/common.cu"
@@ -200,6 +201,7 @@ set(VLLM_EXT_SRC
   "csrc/quantization/gguf/gguf_kernel.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/prepare_inputs/advance_step.cu"
+  "csrc/prepare_inputs/copy_subranges.cu"
   "csrc/torch_bindings.cpp")
 
 if(VLLM_GPU_LANG STREQUAL "CUDA")
@@ -47,3 +47,11 @@
   #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
     hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
 #endif
+
+// #ifndef USE_ROCM
+//   #define VLLM_cudaHostGetDevicePointer(device_ptr, host_ptr, flags) \
+//     cudaHostGetDevicePointer(device_ptr, host_ptr, flags)
+// #else
+//   #define VLLM_cudaHostGetDevicePointer(device_ptr, host_ptr, flags) \
+//     hipHostGetDevicePointer(device_ptr, host_ptr, flags)
+// #endif
csrc/cuda_view.cu (new file, 43 lines)
@@ -0,0 +1,43 @@
#include <torch/all.h>
#include <torch/cuda.h>

// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned
// memory, and that UVA (Unified Virtual Addressing) is enabled.
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
  TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
  TORCH_CHECK(cpu_tensor.is_contiguous(), "Input tensor must be contiguous");

  // Get raw host pointer from CPU tensor
  void* host_ptr = cpu_tensor.data_ptr();

  // Get a device pointer corresponding to the pinned host memory
  void* device_ptr = nullptr;
  cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
  TORCH_CHECK(err == cudaSuccess,
              "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));

  // Construct a CUDA tensor from the device pointer.
  // We'll use the same sizes, strides, and dtype as the CPU tensor.
  auto sizes = cpu_tensor.sizes();
  auto strides = cpu_tensor.strides();
  auto options =
      cpu_tensor.options().device(torch::kCUDA);  // Change device to CUDA

  // from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
  // const TensorOptions &) Provide a no-op deleter. The CPU tensor holds the
  // memory, so we don't free it here.
  auto deleter = [](void*) {
    // no-op, since the memory is owned by the original CPU tensor
  };

  torch::Tensor cuda_tensor =
      torch::from_blob(device_ptr, sizes, strides, deleter, options);

  TORCH_CHECK(cuda_tensor.device().is_cuda(),
              "Resulting tensor is not on CUDA device");
  TORCH_CHECK(cuda_tensor.sizes().equals(sizes), "Size mismatch");
  TORCH_CHECK(cuda_tensor.strides().equals(strides), "Stride mismatch");
  TORCH_CHECK(cuda_tensor.dtype() == cpu_tensor.dtype(), "Dtype mismatch");

  return cuda_tensor;
}
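For orientation, a minimal sketch (not part of the diff) of the zero-copy behavior this kernel helper enables, using the `get_cuda_view_from_cpu_tensor` and `is_uva_available` wrappers that this same branch adds to `vllm.utils`; it assumes a CUDA device with UVA support and a pinned CPU tensor:

import torch

from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available

# Sketch only: requires a CUDA device with UVA support.
if is_uva_available():
    # Pinned (page-locked) host memory is required for cudaHostGetDevicePointer.
    cpu_tensor = torch.zeros(4, 4, dtype=torch.int32, pin_memory=True)
    cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)

    # Writes through the CPU tensor are visible to the CUDA view (and vice
    # versa) without an explicit copy, because both alias the same memory.
    cpu_tensor[0, 0] = 7
    assert int(cuda_view[0, 0]) == 7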
@@ -115,6 +115,11 @@ void advance_step_flashinfer(
     torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
     torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
 
+void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
+                    torch::Tensor& matrix_tgt, int64_t n);
+
+torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
+
 #ifndef USE_ROCM
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,
csrc/prepare_inputs/copy_subranges.cu (new file, 75 lines)
@@ -0,0 +1,75 @@
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

namespace vllm {
__global__ void copy_subranges_kernel(const int* __restrict__ matrix_src,
                                      const int* __restrict__ matrix_diff,
                                      int* __restrict__ matrix_tgt, int64_t M) {
  int row_id = blockIdx.x;
  int row_offset = row_id * M;

  int start = matrix_diff[row_id * 2];
  int length = matrix_diff[row_id * 2 + 1];
  int end = start + length;
  int thread_idx = threadIdx.x;
  for (int i = start + thread_idx; i < end; i += blockDim.x) {
    int idx = row_offset + i;
    matrix_tgt[idx] = matrix_src[idx];
  }
}
}  // namespace vllm

void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
                    torch::Tensor& matrix_tgt, int64_t n) {
  // NOTE(woosuk): Here, we skip most of the error checking to minimize the
  // CPU overheads. We assume that the caller will pass the correct inputs.

  // Check tensor properties
  // TORCH_CHECK(matrix_src.is_cuda(), "matrix_src must be a CUDA tensor");
  // TORCH_CHECK(matrix_diff.is_cuda(), "matrix_diff must be a CUDA tensor");
  // TORCH_CHECK(matrix_tgt.is_cuda(), "matrix_tgt must be a CUDA tensor");
  // TORCH_CHECK(matrix_src.is_contiguous(), "matrix_src must be contiguous");
  // TORCH_CHECK(matrix_diff.is_contiguous(), "matrix_diff must be contiguous");
  // TORCH_CHECK(matrix_tgt.is_contiguous(), "matrix_tgt must be contiguous");

  auto src_sizes = matrix_src.sizes();
  auto diff_sizes = matrix_diff.sizes();
  auto tgt_sizes = matrix_tgt.sizes();

  // TORCH_CHECK(src_sizes.size() == 2, "matrix_src must be 2D");
  // TORCH_CHECK(diff_sizes.size() == 2, "matrix_diff must be 2D");
  // TORCH_CHECK(tgt_sizes.size() == 2, "matrix_tgt must be 2D");

  int64_t N = src_sizes[0];
  int64_t M = src_sizes[1];

  // TORCH_CHECK(diff_sizes[0] == N, "matrix_diff first dim must match N");
  // TORCH_CHECK(diff_sizes[1] == 2, "matrix_diff second dim must be 2");
  // TORCH_CHECK(tgt_sizes[0] == N && tgt_sizes[1] == M,
  //             "matrix_tgt must have same shape as matrix_src");

  // TORCH_CHECK(n <= N, "n must be <= N");

  const int* d_matrix_src = matrix_src.data_ptr<int>();
  const int* d_matrix_diff = matrix_diff.data_ptr<int>();
  int* d_matrix_tgt = matrix_tgt.data_ptr<int>();

  // One thread block per row.
  int blocks = n;
  int threads;
  if (blocks < 128) {
    threads = 1024;
  } else if (blocks < 256) {
    threads = 512;
  } else if (blocks < 512) {
    threads = 256;
  } else {
    threads = 128;
  }
  const at::cuda::OptionalCUDAGuard device_guard(device_of(matrix_tgt));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  vllm::copy_subranges_kernel<<<blocks, threads, 0, stream>>>(
      d_matrix_src, d_matrix_diff, d_matrix_tgt, M);
}
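As a reading aid, here is a rough NumPy sketch (not part of the diff) of what the kernel above computes: for each of the first n rows, matrix_diff[row] = (start, length) names the subrange of that row to copy from matrix_src into matrix_tgt, while the rest of the row in matrix_tgt is left untouched.

import numpy as np

def copy_subranges_reference(matrix_src: np.ndarray, matrix_diff: np.ndarray,
                             matrix_tgt: np.ndarray, n: int) -> None:
    """Reference semantics of the CUDA kernel; one row per thread block."""
    for row in range(n):
        start, length = matrix_diff[row]
        # Copy only the recorded subrange of this row; untouched columns keep
        # whatever values matrix_tgt already holds.
        matrix_tgt[row, start:start + length] = \
            matrix_src[row, start:start + length]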
@@ -21,6 +21,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("weak_ref_tensor(Tensor input) -> Tensor");
   ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
 
+  ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
+  ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
+           &get_cuda_view_from_cpu_tensor);
+
   // Attention ops
   // Compute the attention between an input query and the cached
   // keys/values using PagedAttention.
@@ -98,6 +102,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       ") -> ()");
   ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
 
+  ops.def(
+      "copy_subranges(Tensor matrix_src, Tensor matrix_diff, Tensor! "
+      "matrix_tgt, int n) -> ()");
+  ops.impl("copy_subranges", torch::kCUDA, &copy_subranges);
+
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
tests/kernels/test_copy_subranges.py (new file, 47 lines)
@@ -0,0 +1,47 @@
import random

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform

SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_copy_subranges(seed, device):
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    num_rows = 1024
    num_cols = 1024
    src_matrix = torch.zeros(num_rows,
                             num_cols,
                             device=device,
                             dtype=torch.int32)
    dst_matrix = torch.zeros(num_rows,
                             num_cols,
                             device=device,
                             dtype=torch.int32)
    diff_matrix = torch.zeros(num_rows, 2, device=device, dtype=torch.int32)

    for i in range(num_rows):
        start_idx = random.randint(0, num_cols - 1)
        end_idx = random.randint(start_idx, num_cols - 1)
        num_diffs = end_idx - start_idx

        src_matrix[i, start_idx:end_idx] = torch.randint(0,
                                                         100, (num_diffs, ),
                                                         device=device,
                                                         dtype=torch.int32)

        diff_matrix[i, 0] = start_idx
        diff_matrix[i, 1] = num_diffs

    ops.copy_subranges(src_matrix, diff_matrix, dst_matrix, num_rows)
    assert torch.allclose(src_matrix, dst_matrix, rtol=0, atol=0)
tests/kernels/test_uva.py (new file, 60 lines)
@@ -0,0 +1,60 @@
import pytest
import torch

from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cpu_write(device):
    torch.set_default_device(device)
    cpu_tensor = torch.zeros(10,
                             10,
                             device="cpu",
                             pin_memory=True,
                             dtype=torch.int32)
    cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
    assert cuda_view.device.type == "cuda"

    assert cuda_view[0, 0] == 0
    assert cuda_view[2, 3] == 0
    assert cuda_view[4, 5] == 0

    cpu_tensor[0, 0] = 1
    cpu_tensor[2, 3] = 2
    cpu_tensor[4, 5] = -1

    cuda_view.mul_(2)
    assert cuda_view[0, 0] == 2
    assert cuda_view[2, 3] == 4
    assert cuda_view[4, 5] == -2


@pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_gpu_write(device):
    torch.set_default_device(device)
    cpu_tensor = torch.zeros(10,
                             10,
                             device="cpu",
                             pin_memory=True,
                             dtype=torch.int32)
    cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor)
    assert cuda_view.device.type == "cuda"

    assert cuda_view[0, 0] == 0
    assert cuda_view[2, 3] == 0
    assert cuda_view[4, 5] == 0

    cuda_view[0, 0] = 1
    cuda_view[2, 3] = 2
    cuda_view[4, 5] = -1
    cuda_view.mul_(2)

    assert cpu_tensor[0, 0] == 2
    assert cpu_tensor[2, 3] == 4
    assert cpu_tensor[4, 5] == -2
tests/v1/worker/test_gpu_block_table.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import pytest
import random
import time

import torch

from vllm.v1.worker.gpu_block_table import BlockTable

MAX_NUM_REQS = 1024
MAX_MODEL_LEN = 128 * 1024
BLOCK_SIZE = 16
MAX_NUM_BLOCKS_PER_REQ = MAX_MODEL_LEN // BLOCK_SIZE


def test_block_table(do_wait: bool):
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)

    block_table = BlockTable(
        max_num_reqs=MAX_NUM_REQS,
        max_model_len=MAX_MODEL_LEN,
        max_num_blocks_per_req=MAX_NUM_BLOCKS_PER_REQ,
        pin_memory=True,
        device=torch.device(0),
    )

    num_blocks = random.randint(1, MAX_NUM_BLOCKS_PER_REQ - 1)
    block_ids = torch.randint(0, MAX_NUM_BLOCKS_PER_REQ, (num_blocks,), dtype=torch.int32, device="cpu")
    block_table.add_row(0, block_ids)
    num_blocks = random.randint(1, MAX_NUM_BLOCKS_PER_REQ - 100)
    block_ids = torch.randint(0, MAX_NUM_BLOCKS_PER_REQ, (num_blocks,), dtype=torch.int32, device="cpu")
    block_table.add_row(1, block_ids)
    block_table.commit(2)

    torch.cuda.synchronize()
    if do_wait:
        time.sleep(1)

    block_ids = torch.randint(0, MAX_NUM_BLOCKS_PER_REQ, (100,), dtype=torch.int32, device="cpu")
    block_table.append_row(1, num_blocks, block_ids)
    block_table.move_row(1, 0)
    block_table.commit(2)

    torch.cuda.synchronize()
    if do_wait:
        time.sleep(1)

    torch.testing.assert_close(block_table.block_table[:1].cpu(), block_table.block_table_cpu[:1])


if __name__ == "__main__":
    test_block_table(do_wait=False)
@@ -219,6 +219,20 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
                                          block_table_bound)
 
 
+# copy subrange op. Used for input preparation in the vLLM V1 GPU backend.
+def copy_subranges(
+    src_matrix: torch.Tensor,
+    diff_matrix: torch.Tensor,
+    tgt_matrix: torch.Tensor,
+    num_subranges: int,
+) -> None:
+    # NOTE(woosuk): We use `torch.ops._C.copy_subranges.default` instead of
+    # `torch.ops._C.copy_subranges` to avoid unnecessary CPU overheads from
+    # the dispatcher.
+    torch.ops._C.copy_subranges.default(src_matrix, diff_matrix, tgt_matrix,
+                                        num_subranges)
+
+
 # fused quant layer norm ops
 def rms_norm_dynamic_per_token_quant(
     input: torch.Tensor,
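The NOTE above concerns Python-side dispatch cost: torch.ops._C.copy_subranges is an OpOverloadPacket that resolves the overload on every call, while .default binds the resolved OpOverload directly. A rough, hypothetical micro-benchmark sketch illustrates the difference; it uses the built-in aten.relu op as a stand-in, since _C.copy_subranges only exists in a built vLLM, and the exact savings depend on the PyTorch build:

import time

import torch

x = torch.randn(8, 8)

packet = torch.ops.aten.relu            # OpOverloadPacket: resolves overload per call
overload = torch.ops.aten.relu.default  # OpOverload: already resolved

def bench(fn, iters: int = 10000) -> float:
    # Time repeated calls; the work per call is tiny, so dispatch dominates.
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    return time.perf_counter() - start

print("packet   :", bench(packet))
print("overload :", bench(overload))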
@@ -707,6 +707,14 @@ def is_pin_memory_available() -> bool:
     return current_platform.is_pin_memory_available()
 
 
+@lru_cache(maxsize=None)
+def is_uva_available() -> bool:
+    """Check if Unified Virtual Addressing (UVA) is available."""
+    # UVA requires pinned memory.
+    # TODO(woosuk): Add more requirements for UVA.
+    return is_pin_memory_available()
+
+
 class DeviceMemoryProfiler:
 
     def __init__(self, device: Optional[torch.types.Device] = None):
@@ -1557,6 +1565,14 @@ def weak_ref_tensors(
     raise ValueError("Invalid type for tensors")
 
 
+def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).
+    """
+    assert cpu_tensor.is_pinned(), "CPU tensor must be pinned"
+    return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
+
+
 def is_in_doc_build() -> bool:
     try:
         from sphinx.ext.autodoc.mock import _MockModule
vllm/v1/worker/gpu_block_table.py (new file, 131 lines)
@@ -0,0 +1,131 @@
from typing import List

import numpy as np
import torch

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available

logger = init_logger(__name__)


class BlockTable:

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_blocks_per_req: int,
        pin_memory: bool,
        device: torch.device,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.pin_memory = pin_memory
        self.device = device

        self.block_table = torch.zeros(
            (max_num_reqs, max_num_blocks_per_req),
            device=self.device,
            dtype=torch.int32,
        )
        self.block_table_cpu = torch.zeros(
            (max_num_reqs, max_num_blocks_per_req),
            device="cpu",
            dtype=torch.int32,
            pin_memory=pin_memory,
        )
        self.block_table_np = self.block_table_cpu.numpy()
        self.num_blocks_per_row = np.zeros((max_num_reqs,), dtype=np.int32)

        # UVA requires pinned memory.
        self.use_uva = is_uva_available() and pin_memory
        if self.use_uva:
            logger.info("Using Unified Virtual Addressing (UVA) for block "
                        "table transfer.")
            self.block_table_diff = torch.zeros((max_num_reqs, 2),
                                                dtype=torch.int32,
                                                device="cpu",
                                                pin_memory=True)
            self.block_table_diff_np = self.block_table_diff.numpy()

            self.block_table_cpu_cuda_view = get_cuda_view_from_cpu_tensor(
                self.block_table_cpu)
            self.block_table_diff_cuda_view = get_cuda_view_from_cpu_tensor(
                self.block_table_diff)
        else:
            logger.warning("Unified Virtual Addressing (UVA) is not supported "
                           "in the current environment. This may result in "
                           "lower performance.")

    def add_row(self, row_idx: int, block_ids: List[int]) -> None:
        num_blocks = len(block_ids)
        self.block_table_np[row_idx, :num_blocks] = block_ids
        self.num_blocks_per_row[row_idx] = num_blocks
        if self.use_uva:
            self.block_table_diff_np[row_idx, 0] = 0
            self.block_table_diff_np[row_idx, 1] = num_blocks

    def append_row(
        self,
        row_idx: int,
        start: int,
        block_ids: List[int],
    ) -> None:
        num_blocks = len(block_ids)
        self.block_table_np[row_idx, start:start + num_blocks] = block_ids
        self.num_blocks_per_row[row_idx] = start + num_blocks
        if self.use_uva:
            self.block_table_diff_np[row_idx, 0] = start
            self.block_table_diff_np[row_idx, 1] = num_blocks

    def move_row(self, src: int, tgt: int) -> None:
        num_blocks = self.num_blocks_per_row[src]
        self.block_table_np[tgt, :num_blocks] = \
            self.block_table_np[src, :num_blocks]
        self.num_blocks_per_row[tgt] = num_blocks
        if self.use_uva:
            # Append-and-move is allowed.
            self.block_table_diff_np[tgt, 0] = 0
            self.block_table_diff_np[tgt, 1] = num_blocks
            # Clear the source row.
            self.block_table_diff_np[src].fill(0)

    def commit(self, num_reqs: int) -> None:
        if self.use_uva:
            # Only copy the diff to the GPU.
            ops.copy_subranges(
                self.block_table_cpu_cuda_view,
                self.block_table_diff_cuda_view,
                self.block_table,
                num_reqs,
            )
        else:
            # Copy the entire block table to the GPU.
            # NOTE(woosuk): This can be a performance bottleneck when the block
            # table is large.
            self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
                                              non_blocking=True)
        self.clear_diff()

    def clear(self) -> None:
        self.block_table.fill_(0)
        self.block_table_cpu.fill_(0)
        self.num_blocks_per_row.fill(0)
        if self.use_uva:
            self.block_table_diff.fill_(0)

    def clear_diff(self) -> None:
        if self.use_uva:
            self.block_table_diff_np.fill(0)

    def cuda(self) -> torch.Tensor:
        return self.block_table

    def cpu(self) -> torch.Tensor:
        return self.block_table_cpu

    def numpy(self) -> np.ndarray:
        return self.block_table_np
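To tie the pieces together, a rough usage sketch (not part of the diff) of what a UVA-enabled commit moves: each mutating call records (start, length) per row in the pinned block_table_diff, and copy_subranges then reads both pinned buffers through their CUDA views and copies only those subranges into the device-resident block_table. Sizes and block ids below are hypothetical, and a UVA-capable CUDA device is assumed:

import torch

from vllm.v1.worker.gpu_block_table import BlockTable

# Illustrative sizes only; assumes a CUDA device with UVA support.
bt = BlockTable(max_num_reqs=4, max_model_len=1024,
                max_num_blocks_per_req=64, pin_memory=True,
                device=torch.device("cuda:0"))

bt.add_row(0, [10, 11, 12])   # diff row 0 = (start=0, length=3)
bt.add_row(1, [20, 21])       # diff row 1 = (start=0, length=2)
# One copy_subranges launch; only the recorded subranges of rows 0..1 cross
# the PCIe bus. Without UVA, commit falls back to a full non-blocking copy of
# block_table_cpu[:2]. Commit also clears the diff buffer.
bt.commit(num_reqs=2)

bt.append_row(1, 2, [22])     # diff row 1 = (start=2, length=1)
# This time only the single newly appended block id for row 1 is copied.
bt.commit(num_reqs=2)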
@@ -9,6 +9,7 @@ import torch
 from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.worker.gpu_block_table import BlockTable
 
 if TYPE_CHECKING:
     from vllm.multimodal.inputs import PlaceholderRange
@@ -69,19 +70,14 @@
         self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
         self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
 
-        # Attention-related.
-        self.block_table = torch.zeros(
-            (max_num_reqs, max_num_blocks_per_req),
-            device=self.device,
-            dtype=torch.int32,
-        )
-        self.block_table_cpu_tensor = torch.zeros(
-            (max_num_reqs, max_num_blocks_per_req),
-            device="cpu",
-            dtype=torch.int32,
+        # Block table.
+        self.block_table = BlockTable(
+            max_num_reqs=max_num_reqs,
+            max_model_len=max_model_len,
+            max_num_blocks_per_req=max_num_blocks_per_req,
             pin_memory=pin_memory,
+            device=device,
         )
-        self.block_table_cpu = self.block_table_cpu_tensor.numpy()
 
         # Sampling-related.
         self.temperature = torch.empty((max_num_reqs, ),
@@ -191,8 +187,7 @@
                               start_idx:end_idx] = request.output_token_ids
 
         self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
-        num_blocks = len(request.block_ids)
-        self.block_table_cpu[req_index, :num_blocks] = request.block_ids
+        self.block_table.add_row(req_index, request.block_ids)
 
         sampling_params = request.sampling_params
         self.temperature_cpu[req_index] = sampling_params.temperature
@@ -291,15 +286,14 @@
             self.req_id_to_index[req_id] = empty_index
 
             # TODO(woosuk): Optimize the copy of token_ids_cpu and
-            # block_table_cpu.
+            # block_table.
             self.token_ids_cpu[empty_index] = self.token_ids_cpu[
                 last_req_index]
             self.num_prompt_tokens[empty_index] = \
                 self.num_prompt_tokens[last_req_index]
             self.num_computed_tokens_cpu[
                 empty_index] = self.num_computed_tokens_cpu[last_req_index]
-            self.block_table_cpu[empty_index] = self.block_table_cpu[
-                last_req_index]
+            self.block_table.move_row(last_req_index, empty_index)
             self.temperature_cpu[empty_index] = self.temperature_cpu[
                 last_req_index]
             self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
@@ -204,10 +204,9 @@ class GPUModelRunner:
             if num_new_blocks == 0:
                 continue
             start_index = len(req_state.block_ids)
-            end_index = start_index + num_new_blocks
             req_state.block_ids.extend(req_data.new_block_ids)
-            self.input_batch.block_table_cpu[
-                req_index, start_index:end_index] = req_data.new_block_ids
+            self.input_batch.block_table.append_row(req_index, start_index,
+                                                    req_data.new_block_ids)
 
         req_ids_to_add: List[str] = []
         # Add new requests to the cached states.
@@ -268,9 +267,7 @@
 
         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
-        self.input_batch.block_table[:num_reqs].copy_(
-            self.input_batch.block_table_cpu_tensor[:num_reqs],
-            non_blocking=True)
+        self.input_batch.block_table.commit(num_reqs)
 
         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
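For context on the OPTIMIZATION note above, a small hedged sketch (hypothetical buffers and shapes, not vLLM code) of the pattern being relied on: an asynchronous host-to-device copy from pinned memory is issued first on the current stream, the CPU then continues with other input-preparation work, and later kernels on the same stream see the copied data through stream ordering rather than an explicit sync.

import torch

# Hypothetical pinned staging buffer and device destination.
cpu_buf = torch.randint(0, 100, (1024, 512), dtype=torch.int32).pin_memory()
gpu_buf = torch.empty_like(cpu_buf, device="cuda")

# 1) Kick off the H2D copy without blocking the CPU; pinned memory lets the
#    transfer proceed via DMA while Python keeps executing.
gpu_buf.copy_(cpu_buf, non_blocking=True)

# 2) Do unrelated CPU-side preparation work; it overlaps with the transfer.
cpu_side_result = sum(range(100000))

# 3) Kernels enqueued on the same stream after the copy observe the copied
#    data; stream ordering provides the dependency.
total = gpu_buf.sum()
torch.cuda.synchronize()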
@@ -326,7 +323,7 @@
         # NOTE(woosuk): We use torch.index_select instead of np.take here
         # because torch.index_select is much faster than np.take for large
         # tensors.
-        block_numbers = (self.input_batch.block_table_cpu_tensor.flatten()
+        block_numbers = (self.input_batch.block_table.cpu().flatten()
                          [block_table_indices].numpy())
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
@@ -361,7 +358,7 @@
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_start_loc=seq_start_loc,
-            block_table=self.input_batch.block_table[:num_reqs],
+            block_table=self.input_batch.block_table.cuda()[:num_reqs],
             slot_mapping=slot_mapping,
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial