Compare commits

...

23 Commits

SHA1        Date                        Message / Sign-off
c11d1e6781  2025-08-31 16:40:54 -07:00  optimize spec  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
e696f78e05  2025-08-31 13:29:58 -07:00  minor  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
efcb786d52  2025-08-31 10:44:36 -07:00  merge  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
9ee9d0e274  2025-08-28 15:02:07 -07:00  fix  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
405578121c  2025-08-28 13:19:10 -07:00  minor  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
19c0dfc469  2025-08-28 13:08:07 -07:00  minor  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
e451045a66  2025-08-28 12:55:13 -07:00  fix  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
efba25e21a  2025-08-28 12:39:15 -07:00  minor  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
b21393cd98  2025-08-28 09:58:08 -07:00  Merge branch 'main' into woosuk/input-prep
d6d719fb24  2025-08-28 09:57:49 -07:00  Merge branch 'main' into woosuk/input-prep
e570b0a4de  2025-08-27 21:45:11 -07:00  merge  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
a851aaa0fc  2025-08-25 09:23:05 -07:00  simplify  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
b1d52734f7  2025-08-25 08:55:12 -07:00  fix  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
65f93694be  2025-08-25 08:54:32 -07:00  merge  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
7b4b72e551  2025-08-24 18:49:23 -07:00  fix  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
da9cd26c78  2025-08-24 18:36:33 -07:00  Merge branch 'main' into woosuk/input-prep
a1e3745150  2025-08-24 18:36:18 -07:00  wip  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
48bca9a109  2025-08-23 11:30:29 -07:00  merge  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
64c8cced18  2025-08-22 01:48:35 -07:00  rename  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
79e5eb3643  2025-08-22 01:37:43 -07:00  wip  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
c472982746  2025-08-21 21:40:44 -07:00  merge  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
699bd7928e  2025-08-17 19:28:38 -07:00  Merge branch 'main' into woosuk/input-prep
33a3a26ca5  2025-08-17 14:38:24 -07:00  wip  (Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>)
7 changed files with 1034 additions and 1325 deletions


@@ -12,12 +12,12 @@ from vllm.v1.sample.logits_processor import LogitsProcessors
 @dataclass
 class SamplingMetadata:
-    temperature: Optional[torch.Tensor]
+    temperature: torch.Tensor
     all_greedy: bool
     all_random: bool
-    top_p: Optional[torch.Tensor]
-    top_k: Optional[torch.Tensor]
+    top_p: torch.Tensor
+    top_k: torch.Tensor
     generators: dict[int, torch.Generator]
@@ -25,12 +25,13 @@ class SamplingMetadata:
     max_num_logprobs: Optional[int]
     no_penalties: bool
-    prompt_token_ids: Optional[torch.Tensor]
     frequency_penalties: torch.Tensor
     presence_penalties: torch.Tensor
     repetition_penalties: torch.Tensor
-    output_token_ids: list[list[int]]
+    token_ids: Optional[torch.Tensor]
+    num_tokens: Optional[torch.Tensor]
+    num_prompt_tokens: Optional[torch.Tensor]
     # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
     # vocab size).
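The hunk above replaces the per-request Python lists (prompt_token_ids, output_token_ids) with padded tensors plus per-request lengths (token_ids, num_tokens, num_prompt_tokens). Below is a minimal standalone sketch, not code from this branch, of how a penalty pass could count output tokens directly from such a padded layout; the helper name, the pad convention (reusing vocab_size as a scratch bucket), and the shapes are assumptions.

import torch

def output_token_bincount(token_ids: torch.Tensor,        # [num_reqs, max_len], padded
                          num_tokens: torch.Tensor,        # [num_reqs] total valid tokens
                          num_prompt_tokens: torch.Tensor,  # [num_reqs] prompt lengths
                          vocab_size: int) -> torch.Tensor:
    # Count how often each vocab id appears among the *output* tokens of each request.
    num_reqs, max_len = token_ids.shape
    pos = torch.arange(max_len, device=token_ids.device).expand(num_reqs, -1)
    # Valid output positions satisfy prompt_len <= pos < num_tokens.
    is_output = (pos >= num_prompt_tokens[:, None]) & (pos < num_tokens[:, None])
    # Route padded/prompt positions into a scratch bucket at index vocab_size.
    ids = torch.where(is_output, token_ids, torch.full_like(token_ids, vocab_size))
    counts = torch.zeros(num_reqs, vocab_size + 1, dtype=torch.int32,
                         device=token_ids.device)
    counts.scatter_add_(1, ids.long(), torch.ones_like(ids, dtype=torch.int32))
    return counts[:, :vocab_size]

# Two requests: prompts of length 3 and 2, with 1 and 2 sampled tokens (id 9).
token_ids = torch.tensor([[5, 6, 7, 9, 0], [1, 2, 9, 9, 0]])
counts = output_token_bincount(token_ids,
                               num_tokens=torch.tensor([4, 4]),
                               num_prompt_tokens=torch.tensor([3, 2]),
                               vocab_size=10)
assert counts[0, 9] == 1
assert counts[1, 9] == 2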


@@ -90,9 +90,9 @@ class Sampler(nn.Module):
         # Apply bad words exclusion.
         logits = self.apply_bad_words(logits, sampling_metadata)
-        # Apply logits processors which can impact greedy sampling
-        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
-            logits = processor.apply(logits)
+        # # Apply logits processors which can impact greedy sampling
+        # for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
+        #     logits = processor.apply(logits)
         # Apply penalties (e.g., min_tokens, freq_penalties).
         logits = self.apply_penalties(logits, sampling_metadata)
@@ -167,10 +167,10 @@ class Sampler(nn.Module):
         # Apply temperature.
         logits = self.apply_temperature(logits, sampling_metadata.temperature)
-        # Apply logits processors that only apply to random sampling
-        # (argmax invariant)
-        for processor in sampling_metadata.logitsprocs.argmax_invariant:
-            logits = processor.apply(logits)
+        # # Apply logits processors that only apply to random sampling
+        # # (argmax invariant)
+        # for processor in sampling_metadata.logitsprocs.argmax_invariant:
+        #     logits = processor.apply(logits)
         # Apply top_k and/or top_p.
         random_sampled, processed_logprobs = self.topk_topp_sampler(
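With the logits-processor loops commented out, the sampling path shown in this hunk reduces to temperature scaling followed by top-k/top-p masking and sampling. The snippet below is a self-contained illustration of that reduced path (temperature plus a simple per-row top-k), assuming untied logits; it is a sketch, not the Sampler or TopKTopPSampler implementation, and the function name is local.

import torch

def sample(logits: torch.Tensor, temperature: torch.Tensor,
           top_k: torch.Tensor) -> torch.Tensor:
    logits = logits / temperature[:, None]                  # apply_temperature
    kth = logits.topk(int(top_k.max()), dim=-1).values      # sorted top values per row
    # Each row's threshold is its own k-th largest logit.
    threshold = kth.gather(1, (top_k - 1).clamp(min=0)[:, None])
    logits = logits.masked_fill(logits < threshold, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

tokens = sample(torch.randn(2, 16),
                temperature=torch.tensor([1.0, 0.7]),
                top_k=torch.tensor([5, 3]))
assert tokens.shape == (2,)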


@@ -0,0 +1,309 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
import triton
import triton.language as tl
from vllm.utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
PAD_SLOT_ID = -1
class BlockTables:
def __init__(
self,
block_sizes: list[int],
max_num_reqs: int,
max_num_cached_reqs: int,
max_num_batched_tokens: int,
max_model_len: int,
device: torch.device,
pin_memory: bool,
):
self.block_sizes = block_sizes
self.max_num_reqs = max_num_reqs
self.max_num_cached_reqs = max_num_cached_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.max_model_len = max_model_len
self.device = device
self.pin_memory = pin_memory
self.num_kv_cache_groups = len(self.block_sizes)
# [num_kv_cache_groups, max_num_reqs, max_num_blocks]
self.block_tables: list[torch.Tensor] = []
# [num_kv_cache_groups, max_num_cached_reqs, max_num_blocks]
self.block_table_buffers: list[torch.Tensor] = []
for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i]
max_num_blocks = cdiv(self.max_model_len, block_size)
block_table = torch.zeros(
self.max_num_reqs,
max_num_blocks,
dtype=torch.int32,
device=self.device,
)
self.block_tables.append(block_table)
block_table_buffer = torch.zeros(
self.max_num_cached_reqs,
max_num_blocks,
dtype=torch.int32,
device=self.device,
)
self.block_table_buffers.append(block_table_buffer)
self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
self.buffer_ptrs = self._make_ptr_tensor(self.block_table_buffers)
self.block_table_strides = torch.tensor(
[b.stride(0) for b in self.block_tables],
dtype=torch.int64,
device=self.device)
self.block_sizes_tensor = torch.tensor(self.block_sizes,
dtype=torch.int32,
device=self.device)
self.num_blocks = torch.zeros(self.num_kv_cache_groups,
self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.slot_mappings = torch.zeros(self.num_kv_cache_groups,
self.max_num_batched_tokens,
dtype=torch.int64,
device=self.device)
# Misc buffers.
self.req_indices = self._make_buffer(self.max_num_reqs,
dtype=torch.int32)
self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
self.cu_num_new_blocks = self._make_buffer(self.num_kv_cache_groups,
self.max_num_reqs + 1,
dtype=torch.int32)
# NOTE(woosuk): Here, we assume that total number of new blocks
# is ALWAYS less than max_num_batched_tokens.
# TODO(woosuk): Rigorously verify that this assumption is correct.
self.new_block_ids = self._make_buffer(self.num_kv_cache_groups,
self.max_num_batched_tokens,
dtype=torch.int32)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*args,
dtype=dtype,
pin_memory=self.pin_memory,
device=self.device)
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
ptrs_tensor_cpu = torch.tensor([t.data_ptr() for t in x],
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
return ptrs_tensor_cpu.to(self.device, non_blocking=True)
def append_block_ids(
self,
# [num_reqs]
req_indices: list[int],
# [num_kv_cache_groups, num_reqs + 1]
cu_num_new_blocks: list[list[int]],
# [num_kv_cache_groups, num_new_blocks]
new_block_ids: list[list[int]],
# [num_reqs]
overwrite: list[bool],
) -> None:
# TODO(woosuk): Optimize & simplify this.
num_reqs = len(req_indices)
self.req_indices.np[:num_reqs] = req_indices
self.overwrite.np[:num_reqs] = overwrite
for i in range(self.num_kv_cache_groups):
self.cu_num_new_blocks.np[i, :num_reqs + 1] = cu_num_new_blocks[i]
n = len(new_block_ids[i])
self.new_block_ids.np[i, :n] = new_block_ids[i]
_append_block_ids_kernel[(num_reqs, self.num_kv_cache_groups)](
self.req_indices.copy_to_gpu(num_reqs),
self.cu_num_new_blocks.copy_to_gpu(),
self.cu_num_new_blocks.gpu.stride(0),
self.new_block_ids.copy_to_gpu(),
self.new_block_ids.gpu.stride(0),
self.overwrite.copy_to_gpu(num_reqs),
self.block_table_strides,
self.buffer_ptrs,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024,
)
def compute_block_tables(
self,
idx_mapping: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
batch_size = idx_mapping.shape[0]
_compute_block_tables_kernel[(batch_size, self.num_kv_cache_groups)](
idx_mapping,
self.buffer_ptrs,
self.block_table_ptrs,
self.block_table_strides,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024,
)
return tuple(b[:batch_size] for b in self.block_tables)
def compute_slot_mappings(
self,
cu_num_tokens: torch.Tensor,
pos: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
num_tokens = pos.shape[0]
num_reqs = cu_num_tokens.shape[0] - 1
num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_reqs + 1, num_groups)](
num_tokens,
self.max_num_batched_tokens,
cu_num_tokens,
pos,
self.block_table_ptrs,
self.block_table_strides,
self.block_sizes_tensor,
self.slot_mappings,
self.slot_mappings.stride(0),
PAD_ID=PAD_SLOT_ID,
BLOCK_SIZE=1024,
)
return tuple(x[:num_tokens] for x in self.slot_mappings)
@triton.jit
def _append_block_ids_kernel(
# Inputs
req_indices, # [num_reqs]
cu_num_new_blocks_ptr, # [num_kv_cache_groups, num_reqs + 1]
cu_num_new_blocks_stride,
new_block_ids_ptr, # [num_kv_cache_groups, num_new_blocks]
new_block_ids_stride,
overwrite, # [num_reqs]
block_table_strides, # [num_kv_cache_groups]
# Outputs
block_table_buffer_ptrs, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
# Constants
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
group_id = tl.program_id(1)
req_idx = tl.load(req_indices + batch_idx)
do_overwrite = tl.load(overwrite + batch_idx)
group_new_blocks_ptr = (cu_num_new_blocks_ptr +
group_id * cu_num_new_blocks_stride)
start_idx = tl.load(group_new_blocks_ptr + batch_idx)
end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
num_new_blocks = end_idx - start_idx
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
if do_overwrite:
dst_start_idx = 0
else:
dst_start_idx = tl.load(group_num_blocks_ptr + req_idx)
dst_end_idx = dst_start_idx + num_new_blocks
tl.store(group_num_blocks_ptr + req_idx, dst_end_idx)
# Destination
block_table_buffer_ptr = _load_ptr(block_table_buffer_ptrs + group_id,
tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
buffer_row_ptr = block_table_buffer_ptr + req_idx * block_table_stride
group_new_block_ids_ptr = (new_block_ids_ptr +
group_id * new_block_ids_stride)
for i in tl.range(0, num_new_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(group_new_block_ids_ptr + start_idx + offset,
mask=offset < num_new_blocks)
tl.store(buffer_row_ptr + dst_start_idx + offset,
block_ids,
mask=offset < num_new_blocks)
@triton.jit
def _compute_block_tables_kernel(
batch_idx_to_req_idx, # [batch_size]
src_block_table_ptrs, # [num_kv_cache_groups]
dst_block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
# kv cache group id
group_id = tl.program_id(1)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
num_blocks = tl.load(group_num_blocks_ptr + req_idx)
stride = tl.load(block_table_strides + group_id)
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
src_row_ptr = src_block_table_ptr + req_idx * stride
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
for i in tl.range(0, num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
@triton.jit
def _compute_slot_mappings_kernel(
num_tokens,
max_num_tokens,
cu_num_tokens, # [num_reqs + 1]
pos, # [num_tokens]
block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
page_sizes, # [num_kv_cache_groups]
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
slot_mappings_stride,
PAD_ID: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
# kv cache group id
group_id = tl.program_id(1)
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
if req_idx == tl.num_programs(0) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs.
for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
tl.store(slot_mapping_ptr + offset,
PAD_ID,
mask=offset < max_num_tokens)
return
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
page_size = tl.load(page_sizes + group_id)
start_idx = tl.load(cu_num_tokens + req_idx)
end_idx = tl.load(cu_num_tokens + req_idx + 1)
for i in tl.range(start_idx, end_idx, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
block_indices = positions // page_size
block_numbers = tl.load(block_table_ptr +
req_idx * block_table_stride + block_indices)
slot_ids = block_numbers * page_size + positions % page_size
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
ptr = tl.load(ptr_to_ptr)
return tl.cast(ptr, tl.pointer_type(elem_dtype))
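For reference, the slot-mapping arithmetic in _compute_slot_mappings_kernel (block index = position // page_size, slot = block_number * page_size + position % page_size) can be written in a few lines of plain PyTorch. The sketch below mirrors that math for a single request's block-table row; the function name and example numbers are illustrative only.

import torch

def slot_mapping(block_table_row: torch.Tensor,   # [max_num_blocks] physical block ids
                 positions: torch.Tensor,         # [num_tokens] token positions
                 block_size: int) -> torch.Tensor:
    block_numbers = block_table_row[positions // block_size]
    return block_numbers * block_size + positions % block_size

# A request whose logical blocks 0..2 live in physical blocks 7, 3, 11 (block_size=4):
slots = slot_mapping(torch.tensor([7, 3, 11]), torch.arange(5), block_size=4)
# Positions 0-3 land in physical block 7 (slots 28..31); position 4 lands in block 3 (slot 12).
assert slots.tolist() == [28, 29, 30, 31, 12]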


@@ -1,803 +1,143 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch
from dataclasses import dataclass
from typing import Optional, cast
from typing import Any, Optional
import numba
import numpy as np
import torch
from typing_extensions import deprecated
from numba import types
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (MultiModalKwargsItem,
MultiModalKwargsItems, PlaceholderRange)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
LogitsProcessors,
MoveDirectionality)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@dataclass
class CachedRequestState:
req_id: str
prompt_token_ids: list[int]
mm_kwargs: list[MultiModalKwargsItem]
mm_positions: list[PlaceholderRange]
mm_hashes: list[str]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
generator: Optional[torch.Generator]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
output_token_ids: list[int]
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None
lora_request: Optional[LoRARequest] = None
def __post_init__(self):
self.num_prompt_tokens = len(self.prompt_token_ids)
@property
def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids)
# Temporary back-compatibility for plugins that define model runner
@property
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
"removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargsItems]:
return [
MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs
]
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
return self.prompt_token_ids[idx]
return self.output_token_ids[idx - self.num_prompt_tokens]
class InputBatch:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
logitsprocs: Optional[LogitsProcessors] = None,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.device = device
self.pin_memory = pin_memory
self.vocab_size = vocab_size
# batch_idx -> req_id
req_ids: list[str]
self._req_ids: list[Optional[str]] = []
self.req_id_to_index: dict[str, int] = {}
# req_id -> batch_idx
req_id_to_batch_idx: dict[str, int]
# TODO(woosuk): This buffer could be too large if max_model_len is big.
# Find a way to reduce the CPU memory usage.
# This buffer is not directly transferred to the GPU, so it does not
# need to be pinned.
self.token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len),
device="cpu",
dtype=torch.int32,
pin_memory=False,
# batch_idx -> req_state_idx
idx_mapping: torch.Tensor
idx_mapping_np: np.ndarray
# batch_idx -> num_scheduled_tokens
num_scheduled_tokens: np.ndarray
total_num_tokens: int
max_query_len: int
num_reqs: int
attn_metadata: dict[str, Any]
spec_decode_common_attn_metadata: Optional[Any]
spec_decode_metadata: Optional[SpecDecodeMetadata]
logits_indices: torch.Tensor
# NOTE: With the type annotations, this function is pre-compiled
# before the first call.
@numba.jit(
[
types.none(
types.int32[:], # idx_mapping
types.int32[:, :], # token_ids
types.int32[:], # num_computed_tokens
types.int32[:], # num_scheduled_tokens
types.int32[:], # input_ids
types.int32[:], # query_start_loc
types.int32[:], # seq_lens
types.int64[:], # positions
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu_tensor = torch.zeros(
(max_num_reqs, ),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_computed_tokens_cpu = \
self.num_computed_tokens_cpu_tensor.numpy()
# Block table.
self.block_table = MultiGroupBlockTable(
max_num_reqs=max_num_reqs,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=pin_memory,
device=device,
block_sizes=block_sizes,
)
# Sampling-related.
self.temperature = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
self.greedy_reqs: set[str] = set()
self.random_reqs: set[str] = set()
self.top_p = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
self.top_p_reqs: set[str] = set()
self.top_k = torch.empty((max_num_reqs, ),
dtype=torch.int32,
device=device)
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: set[str] = set()
# IDs of requests which do not support spec decoding
self.spec_decode_unsupported_reqs: set[str] = set()
# Frequency penalty related data structures
self.frequency_penalties = torch.empty((max_num_reqs, ),
dtype=torch.float,
device=device)
self.frequency_penalties_cpu_tensor = torch.empty(
(max_num_reqs, ),
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.frequency_penalties_cpu = \
self.frequency_penalties_cpu_tensor.numpy()
self.frequency_penalties_reqs: set[str] = set()
# Presence penalty related data structures
self.presence_penalties = torch.empty((max_num_reqs, ),
dtype=torch.float,
device=device)
self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
)
self.presence_penalties_reqs: set[str] = set()
# Repetition penalty related data structures
self.repetition_penalties = torch.empty((max_num_reqs, ),
dtype=torch.float,
device=device)
self.repetition_penalties_cpu_tensor = torch.empty(
(max_num_reqs, ),
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.repetition_penalties_cpu = \
self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: set[str] = set()
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
dtype=np.int32)
self.lora_id_to_request_ids: dict[int, set[str]] = {}
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
# req_index -> generator
# NOTE(woosuk): The indices of the requests that do not have their own
# generator should not be included in the dictionary.
self.generators: dict[int, torch.Generator] = {}
self.num_logprobs: dict[str, int] = {}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
# Internal representation of per-step batch state changes, used for
# reordering persistent batch and generating logitsprocs batch state
# updates. Should reset each step.
self.batch_update_builder = BatchUpdateBuilder()
# TODO convert this to LogitsProcessor
self.has_allowed_token_ids: set[str] = set()
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
# the value is False. Since we use masked_fill_ to set -inf.
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
# req_index -> bad_words_token_ids
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
dtype=bool)
self.req_output_token_ids: list[Optional[list[int]]] = []
# Store provided logitsprocs. If none are provided, initialize empty
# data structure
self.logitsprocs = logitsprocs or LogitsProcessors()
# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()
self.pooling_params: dict[str, PoolingParams] = {}
@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
# while performing state updates to the batch.
return cast(list[str], self._req_ids)
def _register_add_request(self, request: "CachedRequestState") -> int:
"""Track add-request operations for logits processors.
Not applicable to pooling models.
"""
# Fill the next empty index if there is one.
if (new_req_index := self.batch_update_builder.pop_removed()) is None:
# Append to end otherwise.
new_req_index = self.num_reqs
assert new_req_index < self.max_num_reqs
self.batch_update_builder.batch_changed = True
if request.sampling_params:
# Detailed added request metadata is only required for non-pooling
# models, to support logitsprocs.
self.batch_update_builder.added.append(
(new_req_index, request.sampling_params,
request.prompt_token_ids, request.output_token_ids))
return new_req_index
def add_request(
self,
request: "CachedRequestState",
) -> int:
req_index = self._register_add_request(request)
req_id = request.req_id
if req_index == len(self._req_ids):
self._req_ids.append(req_id)
self.req_output_token_ids.append(request.output_token_ids)
else:
self._req_ids[req_index] = req_id
self.req_output_token_ids[req_index] = request.output_token_ids
self.req_id_to_index[req_id] = req_index
# Copy the prompt token ids and output token ids.
num_prompt_tokens = len(request.prompt_token_ids)
self.num_prompt_tokens[req_index] = num_prompt_tokens
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids
# Number of token ids in token_ids_cpu.
# NOTE(woosuk): This may include spec decode tokens.
self.num_tokens[req_index] = request.num_tokens
# Number of tokens without spec decode tokens.
self.num_tokens_no_spec[req_index] = request.num_tokens
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(request.block_ids, req_index)
if sampling_params := request.sampling_params:
if (self.is_spec_decode
and is_spec_decode_unsupported(sampling_params)):
self.spec_decode_unsupported_reqs.add(req_id)
if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0
self.greedy_reqs.add(req_id)
else:
self.temperature_cpu[req_index] = sampling_params.temperature
self.random_reqs.add(req_id)
self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id)
top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id)
else:
top_k = self.vocab_size
self.top_k_cpu[req_index] = top_k
self.frequency_penalties_cpu[
req_index] = sampling_params.frequency_penalty
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[
req_index] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[
req_index] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
if request.generator is not None:
self.generators[req_index] = request.generator
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = (self.vocab_size
if sampling_params.logprobs == -1
else sampling_params.logprobs)
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[
req_id] = sampling_params.prompt_logprobs
if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
# Lazy allocation for this tensor, which can be large.
# False means we don't fill with -inf.
self.allowed_token_ids_mask = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device=self.device)
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device="cpu")
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids] = False
if sampling_params.bad_words_token_ids:
self.bad_words_token_ids[
req_index] = sampling_params.bad_words_token_ids
elif pooling_params := request.pooling_params:
self.pooling_params[req_id] = pooling_params
self.logits_processing_needs_token_ids[req_index] = (
pooling_params.requires_token_ids)
else:
raise NotImplementedError("Unrecognized request type")
# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
if lora_id not in self.lora_id_to_request_ids:
self.lora_id_to_request_ids[lora_id] = set()
self.request_lora_mapping[req_index] = lora_id
self.lora_id_to_request_ids[lora_id].add(request.req_id)
self.lora_id_to_lora_request[lora_id] = request.lora_request
else:
# No LoRA
self.request_lora_mapping[req_index] = 0
return req_index
def remove_request(self, req_id: str) -> Optional[int]:
"""This method must always be followed by a call to condense().
Args:
req_id: request to remove
Returns:
Removed request index, or `None` if `req_id` not recognized
"""
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
self.batch_update_builder.removed_append(req_index)
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None
# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
lora_req_ids = self.lora_id_to_request_ids[lora_id]
lora_req_ids.discard(req_id)
if not lora_req_ids:
del self.lora_id_to_request_ids[lora_id]
del self.lora_id_to_lora_request[lora_id]
self.request_lora_mapping[req_index] = 0
if self.is_pooling_model:
self.pooling_params.pop(req_id, None)
return req_index
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.spec_decode_unsupported_reqs.discard(req_id)
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None:
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
self.bad_words_token_ids.pop(req_index, None)
return req_index
def swap_states(self, i1: int, i2: int) -> None:
old_id_i1 = self._req_ids[i1]
old_id_i2 = self._req_ids[i2]
self._req_ids[i1], self._req_ids[i2] =\
self._req_ids[i2], self._req_ids[i1] # noqa
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
assert old_id_i1 is not None and old_id_i2 is not None
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
self.num_tokens[i1], self.num_tokens[i2] =\
self.num_tokens[i2], self.num_tokens[i1]
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
# instead, we need to temporarily copy the data for one of the indices
# TODO(lucas): optimize this by only copying valid indices
tmp = self.token_ids_cpu[i1, ...].copy()
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp
self.block_table.swap_row(i1, i2)
self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
if self.is_pooling_model:
# Sampling and logits parameters don't apply to pooling models.
return
# For autoregressive models, track detailed request reordering info
# to support logitsprocs.
self.batch_update_builder.moved.append(
(i1, i2, MoveDirectionality.SWAP))
self.temperature_cpu[i1], self.temperature_cpu[i2] = \
self.temperature_cpu[i2], self.temperature_cpu[i1]
self.top_p_cpu[i1], self.top_p_cpu[i2] = \
self.top_p_cpu[i2], self.top_p_cpu[i1]
self.top_k_cpu[i1], self.top_k_cpu[i2] = \
self.top_k_cpu[i2], self.top_k_cpu[i1]
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[i1], \
self.allowed_token_ids_mask_cpu_tensor[i2] =\
self.allowed_token_ids_mask_cpu_tensor[i2], \
self.allowed_token_ids_mask_cpu_tensor[i1]
def condense(self) -> None:
"""Slide non-empty requests down into lower, empty indices.
Any consecutive empty indices at the very end of the list are not
filled.
Returns:
swaps: list of (from,to) swap tuples for moved requests
empty_req_indices: indices not filled by condensation
"""
num_reqs = self.num_reqs
if not (empty_req_indices := self.batch_update_builder.removed):
# All removed requests were replaced by added requests, or else no
# requests were removed at all. No condense() needed
return
if num_reqs == 0:
# The batched states are empty.
self._req_ids.clear()
self.req_output_token_ids.clear()
return
# NOTE(woosuk): This function assumes that the empty_req_indices
# is sorted in descending order.
last_req_index = num_reqs + len(empty_req_indices) - 1
while empty_req_indices:
# Find the largest non-empty index.
while last_req_index in empty_req_indices:
last_req_index -= 1
# Find the smallest empty index.
empty_index = self.batch_update_builder.peek_removed()
assert empty_index is not None
if empty_index >= last_req_index:
break
# Move active request down into empty request
# index.
self.batch_update_builder.pop_removed()
req_id = self._req_ids[last_req_index]
output_token_ids = self.req_output_token_ids[last_req_index]
assert req_id is not None
self._req_ids[empty_index] = req_id
self._req_ids[last_req_index] = None
self.req_output_token_ids[empty_index] = output_token_ids
self.req_output_token_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index
num_tokens = self.num_tokens[last_req_index]
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens]
self.num_tokens[empty_index] = num_tokens
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
last_req_index]
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.block_table.move_row(last_req_index, empty_index)
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
last_req_index]
if self.is_pooling_model:
last_req_index -= 1
# Sampling state not used by pooling models.
continue
# Autoregressive models require detailed tracking of condense
# operations to support logitsprocs
self.batch_update_builder.moved.append(
(last_req_index, empty_index,
MoveDirectionality.UNIDIRECTIONAL))
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
self.frequency_penalties_cpu[
empty_index] = self.frequency_penalties_cpu[last_req_index]
self.presence_penalties_cpu[
empty_index] = self.presence_penalties_cpu[last_req_index]
self.repetition_penalties_cpu[
empty_index] = self.repetition_penalties_cpu[last_req_index]
generator = self.generators.pop(last_req_index, None)
if generator is not None:
self.generators[empty_index] = generator
# TODO convert these to LogitsProcessors
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[
empty_index] = self.allowed_token_ids_mask_cpu_tensor[
last_req_index]
bad_words_token_ids = self.bad_words_token_ids.pop(
last_req_index, None)
if bad_words_token_ids is not None:
self.bad_words_token_ids[empty_index] = bad_words_token_ids
# Decrement last_req_index since it is now empty.
last_req_index -= 1
# Trim lists to the batch size.
del self._req_ids[num_reqs:]
del self.req_output_token_ids[num_reqs:]
def refresh_metadata(self):
"""Apply any batch updates to sampling metadata."""
if self.is_pooling_model:
batch_changed = self.batch_update_builder.reset()
if batch_changed:
self.sampling_metadata = self._make_sampling_metadata()
return
# For non-pooling models - generate and apply logitsprocs update;
# reset batch update tracking.
# Update sampling metadata if batch state is changed.
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
for logit_proc in self.logitsprocs.all:
logit_proc.update_state(batch_update)
if batch_update:
self.sampling_metadata = self._make_sampling_metadata()
def _make_sampling_metadata(self) -> SamplingMetadata:
num_reqs = self.num_reqs
if not self.all_greedy:
temperature = copy_slice(self.temperature_cpu_tensor,
self.temperature, num_reqs)
else:
temperature = None
if not self.no_top_p:
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
if not self.no_top_k:
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
if not self.no_penalties:
# Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require
# penalties to be applied during sampling.
copy_slice(self.frequency_penalties_cpu_tensor,
self.frequency_penalties, num_reqs)
copy_slice(self.presence_penalties_cpu_tensor,
self.presence_penalties, num_reqs)
copy_slice(self.repetition_penalties_cpu_tensor,
self.repetition_penalties, num_reqs)
needs_prompt_token_ids = (
not self.no_penalties
or self.logits_processing_needs_token_ids[:num_reqs].any())
if needs_prompt_token_ids:
# The prompt tokens are used only for applying penalties or
# step pooling during the sampling/pooling process.
# Hence copy these tensors only when there are requests which
# need penalties/step_pooler to be applied.
prompt_token_ids = self._make_prompt_token_ids_tensor()
else:
prompt_token_ids = None
allowed_token_ids_mask: Optional[torch.Tensor] = None
if not self.no_allowed_token_ids:
assert self.allowed_token_ids_mask is not None
copy_slice(self.allowed_token_ids_mask_cpu_tensor,
self.allowed_token_ids_mask, num_reqs)
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
return SamplingMetadata(
temperature=temperature,
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=None if self.no_top_p else self.top_p[:num_reqs],
top_k=None if self.no_top_k else self.top_k[:num_reqs],
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=prompt_token_ids,
frequency_penalties=self.frequency_penalties[:num_reqs],
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
no_penalties=self.no_penalties,
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=self.bad_words_token_ids,
logitsprocs=self.logitsprocs,
)
@property
def pooling_metadata(self) -> PoolingMetadata:
if len(self.pooling_params) == 0:
pooling_params = []
else:
# Note, for now this assumes that all request in the batch
# are either sampling or pooling requests
assert len(self.req_ids) == len(self.pooling_params)
pooling_params = [
self.pooling_params[req_id] for req_id in self.req_ids
]
return PoolingMetadata(
prompt_lens=torch.from_numpy(
self.num_prompt_tokens[:self.num_reqs]),
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params,
)
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
num_reqs = self.num_reqs
max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len),
device="cpu",
dtype=torch.int64,
pin_memory=self.pin_memory,
)
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for i in range(num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True)
def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
"""
Given the num_scheduled_tokens for each request in the batch, return
datastructures used to activate the current LoRAs.
Returns:
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
where, token_lora_mapping[i] is the LoRA id to use for ith token.
3. lora_requests: Set of relevant LoRA requests.
"""
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
prompt_lora_mapping = tuple(req_lora_mapping)
token_lora_mapping = tuple(
req_lora_mapping.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set(
self.lora_id_to_lora_request.values())
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
@property
def all_greedy(self) -> bool:
return len(self.random_reqs) == 0
@property
def all_random(self) -> bool:
return len(self.greedy_reqs) == 0
@property
def no_top_p(self) -> bool:
return len(self.top_p_reqs) == 0
@property
def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0
@property
def no_penalties(self) -> bool:
return (len(self.presence_penalties_reqs) == 0
and len(self.frequency_penalties_reqs) == 0
and len(self.repetition_penalties_reqs) == 0)
@property
def max_num_logprobs(self) -> Optional[int]:
return max(self.num_logprobs.values()) if self.num_logprobs else None
@property
def no_prompt_logprob(self) -> bool:
return not self.num_prompt_logprobs
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0
],
nopython=True,
cache=True,
)
def prepare_inputs(
# Inputs
idx_mapping: np.ndarray, # batch_idx -> req_idx
token_ids: np.ndarray, # [N, max_model_len]
num_computed_tokens: np.ndarray, # [N]
num_scheduled_tokens: np.ndarray, # [B]
# Outputs
input_ids: np.ndarray, # [num_input_tokens]
query_start_loc: np.ndarray, # [B + 1]
seq_lens: np.ndarray, # [B]
positions: np.ndarray, # [num_input_tokens]
) -> None:
num_reqs = num_scheduled_tokens.shape[0]
query_start_loc[0] = 0
cu_num_tokens = 0
for i in range(num_reqs):
req_idx = idx_mapping[i]
start = num_computed_tokens[req_idx]
end = start + num_scheduled_tokens[i]
seq_lens[i] = end
start_idx = cu_num_tokens
end_idx = start_idx + num_scheduled_tokens[i]
input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
positions[start_idx:end_idx] = np.arange(start, end)
cu_num_tokens = end_idx
query_start_loc[i + 1] = cu_num_tokens
# Pad the inputs for CUDA graphs.
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that
query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
# Fill unused with 0 for full cuda graph mode.
seq_lens[num_reqs:].fill(0)
def prepare_spec_decode(
# Inputs
query_start_loc: np.ndarray, # [B + 1]
num_draft_tokens: np.ndarray, # [B]
# Outputs
cu_num_draft_tokens: np.ndarray, # [B]
logits_indices: np.ndarray, # [N + B]
target_logits_indices: np.ndarray, # [N]
bonus_logits_indices: np.ndarray, # [B]
) -> int: # N
# Inputs:
# query_start_loc: [ 0, 4, 104, 107, 207, 209]
# num_draft_tokens: [ 3, 0, 2, 0, 1]
# Outputs:
# cu_num_draft_tokens: [ 3, 3, 5, 5, 6]
# logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106,
# 206, 207, 208]
# target_logits_indices: [ 0, 1, 2, 5, 6, 9]
# bonus_logits_indices: [ 3, 4, 7, 8, 10]
# return: 6 (total number of draft tokens)
cu_num_draft = 0
cu_num_sample = 0
num_reqs = num_draft_tokens.shape[0]
for i in range(num_reqs):
q_end_idx = query_start_loc[i + 1]
draft_len = num_draft_tokens[i]
# The last draft_len + 1 query tokens are used for sampling.
sample_len = draft_len + 1
sample_start_idx = cu_num_sample
sample_end_idx = sample_start_idx + sample_len
logits_indices[sample_start_idx:sample_end_idx] = (np.arange(
q_end_idx - sample_len, q_end_idx))
# For each query, the first draft_len tokens need target logits for
# rejection sampling. The draft_len + 1th token is used for bonus token.
draft_start_idx = cu_num_draft
draft_end_idx = draft_start_idx + draft_len
target_logits_indices[draft_start_idx:draft_end_idx] = (np.arange(
sample_start_idx, sample_end_idx - 1))
bonus_logits_indices[i] = sample_end_idx - 1
cu_num_draft += draft_len
cu_num_draft_tokens[i] = cu_num_draft
cu_num_sample += sample_len
return cu_num_draft
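prepare_inputs above gathers each request's newly scheduled token ids and positions into flat, batch-ordered buffers, with query_start_loc recording per-request offsets and seq_lens the post-step sequence lengths. The standalone NumPy sketch below mirrors that gather on a tiny example, omitting the numba compilation and the CUDA-graph padding of the trailing entries; all names are local to the sketch.

import numpy as np

def prepare_inputs_ref(idx_mapping, token_ids, num_computed_tokens,
                       num_scheduled_tokens):
    num_reqs = len(num_scheduled_tokens)
    total = int(num_scheduled_tokens.sum())
    input_ids = np.empty(total, dtype=np.int32)
    positions = np.empty(total, dtype=np.int64)
    query_start_loc = np.zeros(num_reqs + 1, dtype=np.int32)
    seq_lens = np.empty(num_reqs, dtype=np.int32)
    cu = 0
    for i in range(num_reqs):
        req = idx_mapping[i]                      # batch slot -> cached request index
        start = num_computed_tokens[req]
        end = start + num_scheduled_tokens[i]
        seq_lens[i] = end
        input_ids[cu:cu + num_scheduled_tokens[i]] = token_ids[req, start:end]
        positions[cu:cu + num_scheduled_tokens[i]] = np.arange(start, end)
        cu += num_scheduled_tokens[i]
        query_start_loc[i + 1] = cu
    return input_ids, positions, query_start_loc, seq_lens

# Two cached requests; this step schedules 2 tokens of request 1 and 3 tokens of request 0.
token_ids = np.arange(20, dtype=np.int32).reshape(2, 10)
ids, pos, qsl, lens = prepare_inputs_ref(idx_mapping=np.array([1, 0]),
                                         token_ids=token_ids,
                                         num_computed_tokens=np.array([4, 7]),
                                         num_scheduled_tokens=np.array([2, 3]))
assert ids.tolist() == [17, 18, 4, 5, 6]
assert qsl.tolist() == [0, 2, 5]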

File diff suppressed because it is too large.


@@ -0,0 +1,335 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch
from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
import torch
import triton
import triton.language as tl
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (MultiModalKwargsItem,
MultiModalKwargsItems, PlaceholderRange)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
@dataclass
class RequestData:
mm_kwargs: list[MultiModalKwargsItem]
mm_positions: list[PlaceholderRange]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
mm_hashes: list[str]
# M-RoPE (only for Qwen2/2.5-VL)
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None
lora_request: Optional[LoRARequest] = None
# Temporary back-compatibility for plugins that define model runner
@property
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
"removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargsItems]:
return [
MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs
]
class Param:
def __init__(
self,
num_rows_cpu: int,
num_cols: int,
num_rows_gpu: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
is_scalar: bool = False,
):
self.cpu = torch.zeros(num_rows_cpu,
num_cols,
dtype=dtype,
device="cpu",
pin_memory=pin_memory)
self.np = self.cpu.numpy()
self.gpu = torch.zeros(num_rows_gpu,
num_cols,
dtype=dtype,
device=device)
if is_scalar:
self.cpu.squeeze_(1)
self.np = self.cpu.numpy()
self.gpu.squeeze_(1)
# TODO(woosuk): Optimize this.
self.gpu_buffer = self.cpu.to(device)
def mirror_to_gpu(self) -> torch.Tensor:
return self.gpu_buffer.copy_(self.cpu, non_blocking=True)
class RequestState:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
max_num_cached_reqs: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
logitsprocs: Optional[LogitsProcessors] = None,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.max_num_cached_reqs = max_num_cached_reqs
self.device = device
self.pin_memory = pin_memory
self.vocab_size = vocab_size
self.is_spec_decode = is_spec_decode
self.pooling_params = None
self.block_sizes = block_sizes
self.num_prompt_logprobs = {}
self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {}
self.free_indices = list(range(max_num_cached_reqs))
# Request states.
self.req_data: dict[int, RequestData] = {}
# TODO(woosuk): Because the token_ids tensor can be very big, we only
# initialize it on CPU memory.
self.token_ids = self._make_param(
num_cols=self.max_model_len,
dtype=torch.int32,
cpu_only=True,
)
self.num_prompt_tokens = self._make_param(torch.int32)
self.num_tokens = self._make_param(torch.int32)
self.num_computed_tokens = self._make_param(torch.int32)
# Sampling-related.
self.temperature = self._make_param(torch.float32)
self.greedy_reqs: set[str] = set()
self.random_reqs: set[str] = set()
self.top_p = self._make_param(torch.float32)
self.top_p_reqs: set[str] = set()
self.top_k = self._make_param(torch.int32)
self.top_k_reqs: set[str] = set()
self.frequency_penalties = self._make_param(torch.float32)
self.frequency_penalties_reqs: set[str] = set()
self.presence_penalties = self._make_param(torch.float32)
self.presence_penalties_reqs: set[str] = set()
self.repetition_penalties = self._make_param(torch.float32)
self.repetition_penalties_reqs: set[str] = set()
# req_idx -> generator
self.generators: dict[int, torch.Generator] = {}
def _make_param(
self,
dtype: torch.dtype,
num_cols: int = 1,
cpu_only: bool = False,
) -> Param:
return Param(
self.max_num_cached_reqs,
num_cols,
self.max_num_reqs if not cpu_only else 0,
dtype,
self.device,
self.pin_memory,
is_scalar=num_cols == 1,
)
def add_request(
self,
req_id: str,
prompt_token_ids: list[int],
num_computed_tokens: int,
sampling_params: SamplingParams,
) -> None:
req_idx = self.free_indices.pop()
self.req_id_to_index[req_id] = req_idx
self.index_to_req_id[req_idx] = req_id
self.num_prompt_tokens.np[req_idx] = len(prompt_token_ids)
self.num_computed_tokens.np[req_idx] = num_computed_tokens
self.append_token_ids(req_idx, prompt_token_ids)
self.temperature.np[req_idx] = sampling_params.temperature
if sampling_params.sampling_type == SamplingType.GREEDY:
# NOTE: Be careful about division by zero.
self.greedy_reqs.add(req_id)
elif sampling_params.sampling_type == SamplingType.RANDOM:
self.random_reqs.add(req_id)
self.top_p.np[req_idx] = sampling_params.top_p
if sampling_params.top_p < 1.0:
self.top_p_reqs.add(req_id)
top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id)
else:
top_k = self.vocab_size
self.top_k.np[req_idx] = top_k
self.frequency_penalties.np[
req_idx] = sampling_params.frequency_penalty
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties.np[req_idx] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties.np[
req_idx] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
self.generators[req_idx] = generator
def append_token_ids(
self,
req_idx: int,
token_ids: Union[list[int], np.ndarray],
) -> None:
start_idx = self.num_tokens.np[req_idx]
end_idx = start_idx + len(token_ids)
self.token_ids.np[req_idx, start_idx:end_idx] = token_ids
self.num_tokens.np[req_idx] = end_idx
def append_sampled_token_ids(
self,
idx_mapping: np.ndarray,
sampled_token_ids: np.ndarray,
) -> None:
num_reqs = idx_mapping.shape[0]
for i in range(num_reqs):
req_idx = idx_mapping[i]
self.append_token_ids(req_idx, sampled_token_ids[i])
def remove_request(self, req_id: str) -> None:
req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None:
# Request not found.
return
self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx)
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_idx, None)
def make_sampling_metadata(
self,
batch_idx_to_req_idx: torch.Tensor,
) -> SamplingMetadata:
batch_size = batch_idx_to_req_idx.shape[0]
# TODO(woosuk): Use UVA to optimize CPU -> GPU copy.
_make_sampling_metadata_kernel[(batch_size, )](
batch_idx_to_req_idx,
self.temperature.mirror_to_gpu(),
self.temperature.gpu,
self.top_p.mirror_to_gpu(),
self.top_p.gpu,
self.top_k.mirror_to_gpu(),
self.top_k.gpu,
self.frequency_penalties.mirror_to_gpu(),
self.frequency_penalties.gpu,
self.presence_penalties.mirror_to_gpu(),
self.presence_penalties.gpu,
self.repetition_penalties.mirror_to_gpu(),
self.repetition_penalties.gpu,
num_warps=1,
num_stages=1,
)
no_penalties = not (self.frequency_penalties_reqs
or self.presence_penalties_reqs
or self.repetition_penalties_reqs)
return SamplingMetadata(
temperature=self.temperature.gpu[:batch_size],
all_greedy=not self.random_reqs,
all_random=not self.greedy_reqs,
top_p=self.top_p.gpu[:batch_size],
top_k=self.top_k.gpu[:batch_size],
frequency_penalties=self.frequency_penalties.gpu[:batch_size],
presence_penalties=self.presence_penalties.gpu[:batch_size],
repetition_penalties=self.repetition_penalties.gpu[:batch_size],
no_penalties=no_penalties,
# TODO
generators={},
token_ids=None,
num_tokens=None,
num_prompt_tokens=None,
max_num_logprobs=None,
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=None,
)
@property
def num_cached_reqs(self) -> int:
return len(self.req_id_to_index)
@triton.jit
def _make_sampling_metadata_kernel(
batch_idx_to_req_idx, # [batch_size]
src_temperature,
dst_temperature,
src_top_p,
dst_top_p,
src_top_k,
dst_top_k,
src_frequency_penalties,
dst_frequency_penalties,
src_presence_penalties,
dst_presence_penalties,
src_repetition_penalties,
dst_repetition_penalties,
):
batch_idx = tl.program_id(0)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
temperature = tl.load(src_temperature + req_idx)
tl.store(dst_temperature + batch_idx, temperature)
top_p = tl.load(src_top_p + req_idx)
tl.store(dst_top_p + batch_idx, top_p)
top_k = tl.load(src_top_k + req_idx)
tl.store(dst_top_k + batch_idx, top_k)
frequency_penalties = tl.load(src_frequency_penalties + req_idx)
tl.store(dst_frequency_penalties + batch_idx, frequency_penalties)
presence_penalties = tl.load(src_presence_penalties + req_idx)
tl.store(dst_presence_penalties + batch_idx, presence_penalties)
repetition_penalties = tl.load(src_repetition_penalties + req_idx)
tl.store(dst_repetition_penalties + batch_idx, repetition_penalties)
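The Triton kernel above gathers persistent per-request sampling parameters, stored at stable req_idx rows, into dense batch-ordered tensors via batch_idx_to_req_idx. The sketch below expresses the same gather in plain PyTorch with index_select; it is a reference illustration rather than a drop-in replacement for the kernel, and the names are local.

import torch

def gather_sampling_params(batch_idx_to_req_idx: torch.Tensor,
                           **cached: torch.Tensor) -> dict:
    # One gather per parameter: row order follows the batch, not the cache.
    return {name: values.index_select(0, batch_idx_to_req_idx)
            for name, values in cached.items()}

cached_temperature = torch.tensor([1.0, 0.0, 0.5, 0.25])  # one entry per cached request
cached_top_k = torch.tensor([50, 1, 20, 40])
batch = torch.tensor([2, 0])   # this step runs cached requests 2 and 0, in that order
params = gather_sampling_params(batch, temperature=cached_temperature,
                                top_k=cached_top_k)
assert params["temperature"].tolist() == [0.5, 1.0]
assert params["top_k"].tolist() == [20, 50]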


@@ -12,7 +12,6 @@ from vllm.sampling_params import SamplingType
 from vllm.utils import swap_dict_values
 from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.worker.block_table import MultiGroupBlockTable
-from vllm.v1.worker.gpu_input_batch import CachedRequestState
 _SAMPLING_EPS = 1e-5