Compare commits

...

47 Commits

Author SHA1 Message Date
9a6fcca030 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-14 15:56:42 -07:00
633f9f006d Merge branch 'main' into woosuk/input-prep 2025-09-14 08:03:28 -07:00
eb3742c72a fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-13 19:19:40 -07:00
e47bb9970b fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-13 19:19:07 -07:00
5c133fc860 reorder
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-13 19:17:40 -07:00
caf963f2e9 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-13 19:13:08 -07:00
9314a83b56 Merge branch 'main' into woosuk/input-prep 2025-09-14 00:44:56 +00:00
7a50a54390 Merge branch 'main' into woosuk/input-prep 2025-09-13 21:33:54 +00:00
787e59629c wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-08 16:42:26 -07:00
5f95309a6d rename
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-07 12:01:45 -07:00
286eeb91e8 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-07 11:16:37 -07:00
6283995a6c minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-06 21:18:16 -07:00
0c56069c7e merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-06 16:35:45 -07:00
8e6cb9aa4a minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-06 12:23:02 -07:00
ead95fe5dc merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-06 10:56:27 -07:00
23eae07ea5 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-04 20:19:22 -07:00
b16e2d9602 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-01 02:10:48 -07:00
4c2a337e67 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-01 01:45:29 -07:00
cc340e26af top_p top_k
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-01 01:30:08 -07:00
01bf16ede4 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-01 01:16:26 -07:00
af7b6c5dd4 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 23:50:20 -07:00
62d23b3006 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 21:00:16 -07:00
ba1a58f51b MAX_SPEC_LEN
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 20:43:25 -07:00
22771e5d83 work
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 20:41:38 -07:00
c11d1e6781 optimize spec
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 16:40:54 -07:00
e696f78e05 minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 13:29:58 -07:00
efcb786d52 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-31 10:44:36 -07:00
9ee9d0e274 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-28 15:02:07 -07:00
405578121c minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-28 13:19:10 -07:00
19c0dfc469 minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-28 13:08:07 -07:00
e451045a66 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-28 12:55:13 -07:00
efba25e21a minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-28 12:39:15 -07:00
b21393cd98 Merge branch 'main' into woosuk/input-prep 2025-08-28 09:58:08 -07:00
d6d719fb24 Merge branch 'main' into woosuk/input-prep 2025-08-28 09:57:49 -07:00
e570b0a4de merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-27 21:45:11 -07:00
a851aaa0fc simplify
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-25 09:23:05 -07:00
b1d52734f7 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-25 08:55:12 -07:00
65f93694be merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-25 08:54:32 -07:00
7b4b72e551 fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-24 18:49:23 -07:00
da9cd26c78 Merge branch 'main' into woosuk/input-prep 2025-08-24 18:36:33 -07:00
a1e3745150 wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-24 18:36:18 -07:00
48bca9a109 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-23 11:30:29 -07:00
64c8cced18 rename
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-22 01:48:35 -07:00
79e5eb3643 wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-22 01:37:43 -07:00
c472982746 merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-21 21:40:44 -07:00
699bd7928e Merge branch 'main' into woosuk/input-prep 2025-08-17 19:28:38 -07:00
33a3a26ca5 wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-08-17 14:38:24 -07:00
11 changed files with 4098 additions and 1482 deletions

View File

@@ -4,26 +4,21 @@
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import Optional
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm import _custom_ops as ops
logger = init_logger(__name__)
@@ -184,13 +179,6 @@ class TreeAttentionMetadataBuilder(
device=device,
)
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
return reorder_batch_to_split_decodes_and_prefills(
input_batch,
scheduler_output,
decode_threshold=self.tree_attn_bias.shape[0])
def build(
self,
common_prefix_len: int,

View File

@@ -697,69 +697,6 @@ def split_decodes_and_prefills(
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
def reorder_batch_to_split_decodes_and_prefills(
input_batch: "InputBatch",
scheduler_output: "SchedulerOutput",
decode_threshold: int = 1,
) -> bool:
"""
Reorders the batch to split into prefill and decode requests; places all
requests with <= decode_threshold tokens at the front of the batch.
Returns:
True if the batch was modified, False otherwise.
"""
# We now want to reorder the batch so that the "decode" requests are at
# the front and the "prefill" requests are at the back using the least
# amount of swaps possible. (NOTE for now we loosely use "decode" to mean
# requests where attention is likely memory-bound and "prefill" to mean
# requests where attention is likely compute-bound, TODO(lucas): figure out
# a better naming here)
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if it's not,
# we should update this to something like < 8 in the future but
# currently the TritonMLA._forward_decode only supports
# num_tokens = 1
if num_tokens <= decode_threshold:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
# relatively stationary (and new requests are generally appended to the
# persistent batch so already should be at the back)
# To achieve this we loop over the decodes in descending order and
# the prefills in ascending order. We swap decodes from the "back"
# i.e. past where the last decode should be in the reordered batch, with
# prefills from the front of the batch.
# `decodes` and `prefills` are already in ascending order just based on
# the above loop
num_decodes = len(decodes)
num_prefills = len(prefills)
modified_batch = False
for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i, we can swap it
# with the prefill closest to the front of the batch
decode_idx = decodes[num_decodes - i]
if decode_idx < num_decodes:
break
input_batch.swap_states(prefills[i - 1], decode_idx)
modified_batch = True
return modified_batch
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
('logits_indices_padded', Optional[torch.Tensor], None),
('num_logits_indices', int, 0),
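
For context, the sketch below (plain Python, hypothetical names) illustrates the reordering strategy that the removed reorder_batch_to_split_decodes_and_prefills helper implemented: slots with at most decode_threshold scheduled tokens are pulled to the front of the batch with as few swaps as possible. This is not the vLLM implementation, only a minimal standalone model of it.

def reorder_decodes_to_front(num_scheduled_tokens, swap, decode_threshold=1):
    """num_scheduled_tokens: tokens scheduled per batch slot.
    swap(i, j): callback that swaps the persistent state of slots i and j."""
    decodes = [i for i, n in enumerate(num_scheduled_tokens)
               if n <= decode_threshold]
    prefills = [i for i, n in enumerate(num_scheduled_tokens)
                if n > decode_threshold]
    num_decodes = len(decodes)
    modified = False
    for i in range(1, min(num_decodes, len(prefills)) + 1):
        # Decode slot closest to the back of the batch.
        decode_idx = decodes[num_decodes - i]
        if decode_idx < num_decodes:
            # It already sits inside the decode region at the front; done.
            break
        # Swap it with the prefill slot closest to the front.
        swap(prefills[i - 1], decode_idx)
        modified = True
    return modified

# Toy batch: slots scheduling 1 token count as "decodes".
tokens = [1, 7, 1, 5, 1]
def swap(i, j):
    tokens[i], tokens[j] = tokens[j], tokens[i]
reorder_decodes_to_front(tokens, swap)
print(tokens)  # [1, 1, 1, 5, 7]: the three decodes now occupy the front slots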

View File

@@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional
from typing import ClassVar, Optional
import torch
@@ -12,9 +12,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
try:
@@ -26,10 +26,6 @@ try:
except ImportError:
XFORMERS_AVAILABLE = False
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm import _custom_ops as ops
logger = init_logger(__name__)
@@ -213,13 +209,6 @@ class XFormersAttentionMetadataBuilder(
self._num_decodes = 0
self._num_decode_tokens = 0
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
return reorder_batch_to_split_decodes_and_prefills(
input_batch,
scheduler_output,
decode_threshold=self.reorder_batch_threshold)
def build(
self,
common_prefix_len: int,

View File

@@ -12,7 +12,7 @@ from vllm.v1.sample.logits_processor import LogitsProcessors
@dataclass
class SamplingMetadata:
temperature: Optional[torch.Tensor]
temperature: torch.Tensor
all_greedy: bool
all_random: bool
@@ -25,12 +25,13 @@ class SamplingMetadata:
max_num_logprobs: Optional[int]
no_penalties: bool
prompt_token_ids: Optional[torch.Tensor]
frequency_penalties: torch.Tensor
presence_penalties: torch.Tensor
repetition_penalties: torch.Tensor
output_token_ids: list[list[int]]
token_ids: Optional[torch.Tensor]
num_tokens: Optional[torch.Tensor]
num_prompt_tokens: Optional[torch.Tensor]
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
# vocab size).

View File

@@ -90,9 +90,9 @@ class Sampler(nn.Module):
# Apply bad words exclusion.
logits = self.apply_bad_words(logits, sampling_metadata)
# Apply logits processors which can impact greedy sampling
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
logits = processor.apply(logits)
# # Apply logits processors which can impact greedy sampling
# for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
# logits = processor.apply(logits)
# Apply penalties (e.g., min_tokens, freq_penalties).
logits = self.apply_penalties(logits, sampling_metadata)
@@ -167,10 +167,10 @@ class Sampler(nn.Module):
# Apply temperature.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
# Apply logits processors that only apply to random sampling
# (argmax invariant)
for processor in sampling_metadata.logitsprocs.argmax_invariant:
logits = processor.apply(logits)
# # Apply logits processors that only apply to random sampling
# # (argmax invariant)
# for processor in sampling_metadata.logitsprocs.argmax_invariant:
# logits = processor.apply(logits)
# Apply top_k and/or top_p.
random_sampled, processed_logprobs = self.topk_topp_sampler(

View File

@@ -0,0 +1,317 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
import triton
import triton.language as tl
from vllm.utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
PAD_SLOT_ID = -1
class BlockTables:
def __init__(
self,
block_sizes: list[int],
max_num_reqs: int,
max_num_cached_reqs: int,
max_num_batched_tokens: int,
max_model_len: int,
device: torch.device,
pin_memory: bool,
):
self.block_sizes = block_sizes
self.max_num_reqs = max_num_reqs
self.max_num_cached_reqs = max_num_cached_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.max_model_len = max_model_len
self.device = device
self.pin_memory = pin_memory
self.num_kv_cache_groups = len(self.block_sizes)
# [num_kv_cache_groups, max_num_reqs, max_num_blocks]
self.block_tables: list[torch.Tensor] = []
# [num_kv_cache_groups, max_num_cached_reqs, max_num_blocks]
self.block_table_buffers: list[torch.Tensor] = []
for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i]
max_num_blocks = cdiv(self.max_model_len, block_size)
block_table = torch.zeros(
self.max_num_reqs,
max_num_blocks,
dtype=torch.int32,
device=self.device,
)
self.block_tables.append(block_table)
block_table_buffer = torch.zeros(
self.max_num_cached_reqs,
max_num_blocks,
dtype=torch.int32,
device=self.device,
)
self.block_table_buffers.append(block_table_buffer)
self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
self.buffer_ptrs = self._make_ptr_tensor(self.block_table_buffers)
self.block_table_strides = torch.tensor(
[b.stride(0) for b in self.block_tables],
dtype=torch.int64,
device=self.device)
self.block_sizes_tensor = torch.tensor(self.block_sizes,
dtype=torch.int32,
device=self.device)
self.num_blocks = torch.zeros(self.num_kv_cache_groups,
self.max_num_cached_reqs,
dtype=torch.int32,
device=self.device)
self.slot_mappings = torch.zeros(self.num_kv_cache_groups,
self.max_num_batched_tokens,
dtype=torch.int64,
device=self.device)
# Misc buffers.
self.req_indices = self._make_buffer(self.max_num_reqs,
dtype=torch.int32)
self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
self.cu_num_new_blocks = self._make_buffer(self.num_kv_cache_groups,
self.max_num_reqs + 1,
dtype=torch.int32)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*args,
dtype=dtype,
pin_memory=self.pin_memory,
device=self.device)
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
ptrs_tensor_cpu = torch.tensor([t.data_ptr() for t in x],
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
return ptrs_tensor_cpu.to(self.device, non_blocking=True)
def append_block_ids(
self,
# [num_reqs]
req_indices: list[int],
# [num_kv_cache_groups, num_reqs + 1]
cu_num_new_blocks: list[list[int]],
# [num_kv_cache_groups, num_new_blocks]
new_block_ids: list[list[int]],
# [num_reqs]
overwrite: list[bool],
) -> None:
# TODO(woosuk): Optimize & simplify this.
num_reqs = len(req_indices)
self.req_indices.np[:num_reqs] = req_indices
self.overwrite.np[:num_reqs] = overwrite
for i in range(self.num_kv_cache_groups):
self.cu_num_new_blocks.np[i, :num_reqs + 1] = cu_num_new_blocks[i]
# NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
# no clear upper bound on the number of new blocks.
new_block_ids_cpu = torch.empty(
self.num_kv_cache_groups,
max(len(x) for x in new_block_ids),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
new_block_ids_np = new_block_ids_cpu.numpy()
for i in range(self.num_kv_cache_groups):
new_block_ids_np[i, :len(new_block_ids[i])] = new_block_ids[i]
new_block_ids_gpu = new_block_ids_cpu.to(self.device,
non_blocking=True)
_append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)](
self.req_indices.copy_to_gpu(num_reqs),
self.cu_num_new_blocks.copy_to_gpu(),
self.cu_num_new_blocks.gpu.stride(0),
new_block_ids_gpu,
new_block_ids_gpu.stride(0),
self.overwrite.copy_to_gpu(num_reqs),
self.block_table_strides,
self.buffer_ptrs,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024,
)
def compute_block_tables(
self,
idx_mapping: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
batch_size = idx_mapping.shape[0]
_compute_block_tables_kernel[(self.num_kv_cache_groups, batch_size)](
idx_mapping,
self.buffer_ptrs,
self.block_table_ptrs,
self.block_table_strides,
self.num_blocks,
self.num_blocks.stride(0),
BLOCK_SIZE=1024,
)
return tuple(b[:batch_size] for b in self.block_tables)
def compute_slot_mappings(
self,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
num_reqs = query_start_loc.shape[0] - 1
num_tokens = positions.shape[0]
num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
num_tokens,
self.max_num_batched_tokens,
query_start_loc,
positions,
self.block_table_ptrs,
self.block_table_strides,
self.block_sizes_tensor,
self.slot_mappings,
self.slot_mappings.stride(0),
PAD_ID=PAD_SLOT_ID,
BLOCK_SIZE=1024,
)
return tuple(x[:num_tokens] for x in self.slot_mappings)
@triton.jit
def _append_block_ids_kernel(
# Inputs
req_indices, # [num_reqs]
cu_num_new_blocks_ptr, # [num_kv_cache_groups, num_reqs + 1]
cu_num_new_blocks_stride,
new_block_ids_ptr, # [num_kv_cache_groups, num_new_blocks]
new_block_ids_stride,
overwrite, # [num_reqs]
block_table_strides, # [num_kv_cache_groups]
# Outputs
block_table_buffer_ptrs, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_cached_reqs]
num_blocks_stride,
# Constants
BLOCK_SIZE: tl.constexpr,
):
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(req_indices + batch_idx)
do_overwrite = tl.load(overwrite + batch_idx)
group_new_blocks_ptr = (cu_num_new_blocks_ptr +
group_id * cu_num_new_blocks_stride)
start_idx = tl.load(group_new_blocks_ptr + batch_idx)
end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
num_new_blocks = end_idx - start_idx
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
if do_overwrite:
dst_start_idx = 0
else:
dst_start_idx = tl.load(group_num_blocks_ptr + req_idx)
dst_end_idx = dst_start_idx + num_new_blocks
tl.store(group_num_blocks_ptr + req_idx, dst_end_idx)
# Destination
block_table_buffer_ptr = _load_ptr(block_table_buffer_ptrs + group_id,
tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
buffer_row_ptr = block_table_buffer_ptr + req_idx * block_table_stride
group_new_block_ids_ptr = (new_block_ids_ptr +
group_id * new_block_ids_stride)
for i in tl.range(0, num_new_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(group_new_block_ids_ptr + start_idx + offset,
mask=offset < num_new_blocks)
tl.store(buffer_row_ptr + dst_start_idx + offset,
block_ids,
mask=offset < num_new_blocks)
@triton.jit
def _compute_block_tables_kernel(
batch_idx_to_req_idx, # [batch_size]
src_block_table_ptrs, # [num_kv_cache_groups]
dst_block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_cached_reqs]
num_blocks_stride,
BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
num_blocks = tl.load(group_num_blocks_ptr + req_idx)
stride = tl.load(block_table_strides + group_id)
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
src_row_ptr = src_block_table_ptr + req_idx * stride
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
for i in tl.range(0, num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
@triton.jit
def _compute_slot_mappings_kernel(
num_tokens,
max_num_tokens,
cu_num_tokens, # [num_reqs + 1]
pos, # [num_tokens]
block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
page_sizes, # [num_kv_cache_groups]
slot_mappings_ptr, # [num_kv_cache_groups, max_num_tokens]
slot_mappings_stride,
PAD_ID: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
req_idx = tl.program_id(1)
slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride
if req_idx == tl.num_programs(1) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs.
for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
tl.store(slot_mapping_ptr + offset,
PAD_ID,
mask=offset < max_num_tokens)
return
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
page_size = tl.load(page_sizes + group_id)
start_idx = tl.load(cu_num_tokens + req_idx)
end_idx = tl.load(cu_num_tokens + req_idx + 1)
for i in tl.range(start_idx, end_idx, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
block_indices = positions // page_size
block_numbers = tl.load(block_table_ptr +
req_idx * block_table_stride + block_indices)
slot_ids = block_numbers * page_size + positions % page_size
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
ptr = tl.load(ptr_to_ptr)
ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
return tl.multiple_of(ptr, 16)
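
A minimal usage sketch of the new BlockTables helper follows. The module path, sizes, and block IDs are illustrative assumptions (the file name is not shown in this diff), and a CUDA device is assumed because the Triton kernels above run on the GPU.

import torch
# Hypothetical import path; adjust to wherever this new module actually lives.
from vllm.v1.worker.block_table import BlockTables

device = torch.device("cuda")
block_tables = BlockTables(
    block_sizes=[16],            # one KV-cache group with block_size 16
    max_num_reqs=8,
    max_num_cached_reqs=32,
    max_num_batched_tokens=256,
    max_model_len=128,
    device=device,
    pin_memory=True,
)

# Cached request 3 is given physical blocks [10, 11]; overwrite=True starts
# its row in the block table buffer from scratch.
block_tables.append_block_ids(
    req_indices=[3],
    cu_num_new_blocks=[[0, 2]],  # per group: cumulative new-block counts
    new_block_ids=[[10, 11]],
    overwrite=[True],
)

# Gather the per-step block table for a batch that only contains request 3.
idx_mapping = torch.tensor([3], dtype=torch.int32, device=device)
(block_table,) = block_tables.compute_block_tables(idx_mapping)

# Slot mapping for 20 scheduled tokens at positions 0..19 of that request,
# e.g. position 17 -> block index 1 -> block 11 -> slot 11 * 16 + 1 = 177.
query_start_loc = torch.tensor([0, 20], dtype=torch.int32, device=device)
positions = torch.arange(20, dtype=torch.int64, device=device)
(slot_mapping,) = block_tables.compute_slot_mappings(query_start_loc, positions)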

View File

@@ -1,819 +1,92 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch
from dataclasses import dataclass
from typing import Optional, cast
from typing import Any, Optional
import numba
import numpy as np
import torch
from typing_extensions import deprecated
from numba import types
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
LogitsProcessors,
MoveDirectionality)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@dataclass
class CachedRequestState:
req_id: str
prompt_token_ids: list[int]
mm_features: list[MultiModalFeatureSpec]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
generator: Optional[torch.Generator]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
output_token_ids: list[int]
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None
lora_request: Optional[LoRARequest] = None
def __post_init__(self):
self.num_prompt_tokens = len(self.prompt_token_ids)
@property
def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids)
# Temporary back-compatibility for plugins that define model runner
@property
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
"removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargsItems]:
return [
MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
if f.data is not None
]
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
return self.prompt_token_ids[idx]
return self.output_token_ids[idx - self.num_prompt_tokens]
class InputBatch:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
logitsprocs: Optional[LogitsProcessors] = None,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.device = device
self.pin_memory = pin_memory
self.vocab_size = vocab_size
# batch_idx -> req_id
req_ids: list[str]
self._req_ids: list[Optional[str]] = []
self.req_id_to_index: dict[str, int] = {}
# req_id -> batch_idx
req_id_to_batch_idx: dict[str, int]
# TODO(woosuk): This buffer could be too large if max_model_len is big.
# Find a way to reduce the CPU memory usage.
# This buffer is not directly transferred to the GPU, so it does not
# need to be pinned.
self.token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len),
device="cpu",
dtype=torch.int32,
pin_memory=False,
# batch_idx -> req_state_idx
idx_mapping: torch.Tensor
idx_mapping_np: np.ndarray
# batch_idx -> num_scheduled_tokens
num_scheduled_tokens: np.ndarray
total_num_tokens: int
max_query_len: int
num_reqs: int
attn_metadata: dict[str, Any]
spec_decode_common_attn_metadata: Optional[Any]
spec_decode_metadata: Optional[SpecDecodeMetadata]
logits_indices: torch.Tensor
# NOTE: With the type annotations, this function is pre-compiled
# before the first call.
@numba.jit(
[
types.none(
types.int32[:], # idx_mapping
types.int32[:, :], # token_ids
types.int32[:], # num_computed_tokens
types.int32[:], # num_scheduled_tokens
types.int32[:], # input_ids
types.int32[:], # query_start_loc
types.int32[:], # seq_lens
types.int64[:], # positions
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu_tensor = torch.zeros(
(max_num_reqs, ),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_computed_tokens_cpu = \
self.num_computed_tokens_cpu_tensor.numpy()
# Block table.
self.block_table = MultiGroupBlockTable(
max_num_reqs=max_num_reqs,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=pin_memory,
device=device,
block_sizes=block_sizes,
num_speculative_tokens=num_speculative_tokens,
)
# Sampling-related.
self.temperature = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
self.greedy_reqs: set[str] = set()
self.random_reqs: set[str] = set()
self.top_p = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
self.top_p_reqs: set[str] = set()
self.top_k = torch.empty((max_num_reqs, ),
dtype=torch.int32,
device=device)
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: set[str] = set()
# IDs of requests which do not support spec decoding
self.spec_decode_unsupported_reqs: set[str] = set()
# Frequency penalty related data structures
self.frequency_penalties = torch.empty((max_num_reqs, ),
dtype=torch.float,
device=device)
self.frequency_penalties_cpu_tensor = torch.empty(
(max_num_reqs, ),
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.frequency_penalties_cpu = \
self.frequency_penalties_cpu_tensor.numpy()
self.frequency_penalties_reqs: set[str] = set()
# Presence penalty related data structures
self.presence_penalties = torch.empty((max_num_reqs, ),
dtype=torch.float,
device=device)
self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
)
self.presence_penalties_reqs: set[str] = set()
# Repetition penalty related data structures
self.repetition_penalties = torch.empty((max_num_reqs, ),
dtype=torch.float,
device=device)
self.repetition_penalties_cpu_tensor = torch.empty(
(max_num_reqs, ),
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.repetition_penalties_cpu = \
self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: set[str] = set()
# Speculative decoding
self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ),
dtype=torch.int64,
device="cpu",
pin_memory=pin_memory)
self.num_accepted_tokens_cpu = \
self.num_accepted_tokens_cpu_tensor.numpy()
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
dtype=np.int32)
self.lora_id_to_request_ids: dict[int, set[str]] = {}
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
# req_index -> generator
# NOTE(woosuk): The indices of the requests that do not have their own
# generator should not be included in the dictionary.
self.generators: dict[int, torch.Generator] = {}
self.num_logprobs: dict[str, int] = {}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
# Internal representation of per-step batch state changes, used for
# reordering persistent batch and generating logitsprocs batch state
# updates. Should reset each step.
self.batch_update_builder = BatchUpdateBuilder()
# TODO convert this to LogitsProcessor
self.has_allowed_token_ids: set[str] = set()
# NOTE(lufang): In the mask tensor, if the corresponding token is allowed,
# the value is False, since we use masked_fill_ to set -inf.
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
# req_index -> bad_words_token_ids
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
dtype=bool)
self.req_output_token_ids: list[Optional[list[int]]] = []
# Store provided logitsprocs. If none are provided, initialize empty
# data structure
self.logitsprocs = logitsprocs or LogitsProcessors()
# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()
self.pooling_params: dict[str, PoolingParams] = {}
# Cached reference to the GPU tensor of previously sampled tokens
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
self.prev_req_id_to_index: Optional[dict[str, int]] = None
@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
# while performing state updates to the batch.
return cast(list[str], self._req_ids)
def _register_add_request(self, request: "CachedRequestState") -> int:
"""Track add-request operations for logits processors.
Not applicable to pooling models.
"""
# Fill the next empty index if there is one.
if (new_req_index := self.batch_update_builder.pop_removed()) is None:
# Append to end otherwise.
new_req_index = self.num_reqs
assert new_req_index < self.max_num_reqs
self.batch_update_builder.batch_changed = True
if request.sampling_params:
# Detailed added request metadata is only required for non-pooling
# models, to support logitsprocs.
self.batch_update_builder.added.append(
(new_req_index, request.sampling_params,
request.prompt_token_ids, request.output_token_ids))
return new_req_index
def add_request(
self,
request: "CachedRequestState",
) -> int:
req_index = self._register_add_request(request)
req_id = request.req_id
if req_index == len(self._req_ids):
self._req_ids.append(req_id)
self.req_output_token_ids.append(request.output_token_ids)
else:
self._req_ids[req_index] = req_id
self.req_output_token_ids[req_index] = request.output_token_ids
self.req_id_to_index[req_id] = req_index
# Copy the prompt token ids and output token ids.
num_prompt_tokens = len(request.prompt_token_ids)
self.num_prompt_tokens[req_index] = num_prompt_tokens
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids
# Number of token ids in token_ids_cpu.
# NOTE(woosuk): This may include spec decode tokens.
self.num_tokens[req_index] = request.num_tokens
# Number of tokens without spec decode tokens.
self.num_tokens_no_spec[req_index] = request.num_tokens
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(request.block_ids, req_index)
if sampling_params := request.sampling_params:
if (self.is_spec_decode
and is_spec_decode_unsupported(sampling_params)):
self.spec_decode_unsupported_reqs.add(req_id)
if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0
self.greedy_reqs.add(req_id)
else:
self.temperature_cpu[req_index] = sampling_params.temperature
self.random_reqs.add(req_id)
self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id)
top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id)
else:
top_k = self.vocab_size
self.top_k_cpu[req_index] = top_k
self.frequency_penalties_cpu[
req_index] = sampling_params.frequency_penalty
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[
req_index] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[
req_index] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
if request.generator is not None:
self.generators[req_index] = request.generator
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = (self.vocab_size
if sampling_params.logprobs == -1
else sampling_params.logprobs)
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = (
self.vocab_size if sampling_params.prompt_logprobs == -1
else sampling_params.prompt_logprobs)
if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
# Lazy allocation for this tensor, which can be large.
# False means we don't fill with -inf.
self.allowed_token_ids_mask = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device=self.device)
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
self.max_num_reqs,
self.vocab_size,
dtype=torch.bool,
device="cpu")
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index][
sampling_params.allowed_token_ids] = False
if sampling_params.bad_words_token_ids:
self.bad_words_token_ids[
req_index] = sampling_params.bad_words_token_ids
elif pooling_params := request.pooling_params:
self.pooling_params[req_id] = pooling_params
self.logits_processing_needs_token_ids[req_index] = (
pooling_params.requires_token_ids)
else:
raise NotImplementedError("Unrecognized request type")
# Speculative decoding: by default 1 token is generated.
self.num_accepted_tokens_cpu[req_index] = 1
# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
if lora_id not in self.lora_id_to_request_ids:
self.lora_id_to_request_ids[lora_id] = set()
self.request_lora_mapping[req_index] = lora_id
self.lora_id_to_request_ids[lora_id].add(request.req_id)
self.lora_id_to_lora_request[lora_id] = request.lora_request
else:
# No LoRA
self.request_lora_mapping[req_index] = 0
return req_index
def remove_request(self, req_id: str) -> Optional[int]:
"""This method must always be followed by a call to condense().
Args:
req_id: request to remove
Returns:
Removed request index, or `None` if `req_id` not recognized
"""
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
self.batch_update_builder.removed_append(req_index)
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None
# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
lora_req_ids = self.lora_id_to_request_ids[lora_id]
lora_req_ids.discard(req_id)
if not lora_req_ids:
del self.lora_id_to_request_ids[lora_id]
del self.lora_id_to_lora_request[lora_id]
self.request_lora_mapping[req_index] = 0
if self.is_pooling_model:
self.pooling_params.pop(req_id, None)
return req_index
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.spec_decode_unsupported_reqs.discard(req_id)
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None:
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
self.bad_words_token_ids.pop(req_index, None)
return req_index
def swap_states(self, i1: int, i2: int) -> None:
old_id_i1 = self._req_ids[i1]
old_id_i2 = self._req_ids[i2]
self._req_ids[i1], self._req_ids[i2] =\
self._req_ids[i2], self._req_ids[i1] # noqa
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
assert old_id_i1 is not None and old_id_i2 is not None
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
self.num_tokens[i1], self.num_tokens[i2] =\
self.num_tokens[i2], self.num_tokens[i1]
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
# instead, we need to temporarily copy the data for one of the indices
# TODO(lucas): optimize this by only copying valid indices
tmp = self.token_ids_cpu[i1, ...].copy()
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp
self.block_table.swap_row(i1, i2)
self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
if self.is_pooling_model:
# Sampling and logits parameters don't apply to pooling models.
return
# For autoregressive models, track detailed request reordering info
# to support logitsprocs.
self.batch_update_builder.moved.append(
(i1, i2, MoveDirectionality.SWAP))
self.temperature_cpu[i1], self.temperature_cpu[i2] = \
self.temperature_cpu[i2], self.temperature_cpu[i1]
self.top_p_cpu[i1], self.top_p_cpu[i2] = \
self.top_p_cpu[i2], self.top_p_cpu[i1]
self.top_k_cpu[i1], self.top_k_cpu[i2] = \
self.top_k_cpu[i2], self.top_k_cpu[i1]
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\
self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1]
swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[i1], \
self.allowed_token_ids_mask_cpu_tensor[i2] =\
self.allowed_token_ids_mask_cpu_tensor[i2], \
self.allowed_token_ids_mask_cpu_tensor[i1]
def condense(self) -> None:
"""Slide non-empty requests down into lower, empty indices.
Any consecutive empty indices at the very end of the list are not
filled.
Returns:
swaps: list of (from,to) swap tuples for moved requests
empty_req_indices: indices not filled by condensation
"""
num_reqs = self.num_reqs
if not (empty_req_indices := self.batch_update_builder.removed):
# All removed requests were replaced by added requests, or else no
# requests were removed at all. No condense() needed
return
if num_reqs == 0:
# The batched states are empty.
self._req_ids.clear()
self.req_output_token_ids.clear()
return
# NOTE(woosuk): This function assumes that the empty_req_indices
# is sorted in descending order.
last_req_index = num_reqs + len(empty_req_indices) - 1
while empty_req_indices:
# Find the largest non-empty index.
while last_req_index in empty_req_indices:
last_req_index -= 1
# Find the smallest empty index.
empty_index = self.batch_update_builder.peek_removed()
assert empty_index is not None
if empty_index >= last_req_index:
break
# Move active request down into empty request
# index.
self.batch_update_builder.pop_removed()
req_id = self._req_ids[last_req_index]
output_token_ids = self.req_output_token_ids[last_req_index]
assert req_id is not None
self._req_ids[empty_index] = req_id
self._req_ids[last_req_index] = None
self.req_output_token_ids[empty_index] = output_token_ids
self.req_output_token_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index
num_tokens = self.num_tokens[last_req_index]
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens]
self.num_tokens[empty_index] = num_tokens
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
last_req_index]
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.block_table.move_row(last_req_index, empty_index)
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
last_req_index]
if self.is_pooling_model:
last_req_index -= 1
# Sampling state not used by pooling models.
continue
# Autoregressive models require detailed tracking of condense
# operations to support logitsprocs
self.batch_update_builder.moved.append(
(last_req_index, empty_index,
MoveDirectionality.UNIDIRECTIONAL))
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
self.frequency_penalties_cpu[
empty_index] = self.frequency_penalties_cpu[last_req_index]
self.presence_penalties_cpu[
empty_index] = self.presence_penalties_cpu[last_req_index]
self.repetition_penalties_cpu[
empty_index] = self.repetition_penalties_cpu[last_req_index]
self.num_accepted_tokens_cpu[
empty_index] = self.num_accepted_tokens_cpu[last_req_index]
generator = self.generators.pop(last_req_index, None)
if generator is not None:
self.generators[empty_index] = generator
# TODO convert these to LogitsProcessors
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[
empty_index] = self.allowed_token_ids_mask_cpu_tensor[
last_req_index]
bad_words_token_ids = self.bad_words_token_ids.pop(
last_req_index, None)
if bad_words_token_ids is not None:
self.bad_words_token_ids[empty_index] = bad_words_token_ids
# Decrement last_req_index since it is now empty.
last_req_index -= 1
# Trim lists to the batch size.
del self._req_ids[num_reqs:]
del self.req_output_token_ids[num_reqs:]
def refresh_metadata(self):
"""Apply any batch updates to sampling metadata."""
if self.is_pooling_model:
batch_changed = self.batch_update_builder.reset()
if batch_changed:
self.sampling_metadata = self._make_sampling_metadata()
return
# For non-pooling models - generate and apply logitsprocs update;
# reset batch update tracking.
# Update sampling metadata if batch state is changed.
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
for logit_proc in self.logitsprocs.all:
logit_proc.update_state(batch_update)
if batch_update:
self.sampling_metadata = self._make_sampling_metadata()
def _make_sampling_metadata(self) -> SamplingMetadata:
num_reqs = self.num_reqs
if not self.all_greedy:
temperature = copy_slice(self.temperature_cpu_tensor,
self.temperature, num_reqs)
else:
temperature = None
if not self.no_top_p:
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
if not self.no_top_k:
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
if not self.no_penalties:
# Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require
# penalties to be applied during sampling.
copy_slice(self.frequency_penalties_cpu_tensor,
self.frequency_penalties, num_reqs)
copy_slice(self.presence_penalties_cpu_tensor,
self.presence_penalties, num_reqs)
copy_slice(self.repetition_penalties_cpu_tensor,
self.repetition_penalties, num_reqs)
needs_prompt_token_ids = (
not self.no_penalties
or self.logits_processing_needs_token_ids[:num_reqs].any())
if needs_prompt_token_ids:
# The prompt tokens are used only for applying penalties or
# step pooling during the sampling/pooling process.
# Hence copy these tensors only when there are requests which
# need penalties/step_pooler to be applied.
prompt_token_ids = self._make_prompt_token_ids_tensor()
else:
prompt_token_ids = None
allowed_token_ids_mask: Optional[torch.Tensor] = None
if not self.no_allowed_token_ids:
assert self.allowed_token_ids_mask is not None
copy_slice(self.allowed_token_ids_mask_cpu_tensor,
self.allowed_token_ids_mask, num_reqs)
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
return SamplingMetadata(
temperature=temperature,
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=None if self.no_top_p else self.top_p[:num_reqs],
top_k=None if self.no_top_k else self.top_k[:num_reqs],
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=prompt_token_ids,
frequency_penalties=self.frequency_penalties[:num_reqs],
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
no_penalties=self.no_penalties,
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=self.bad_words_token_ids,
logitsprocs=self.logitsprocs,
)
def get_pooling_params(self) -> list[PoolingParams]:
assert len(self.req_ids) == len(self.pooling_params)
return [self.pooling_params[req_id] for req_id in self.req_ids]
def get_pooling_metadata(self) -> PoolingMetadata:
pooling_params = self.get_pooling_params()
return PoolingMetadata(
prompt_lens=torch.from_numpy(
self.num_prompt_tokens[:self.num_reqs]),
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params,
)
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
num_reqs = self.num_reqs
max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len),
device="cpu",
dtype=torch.int64,
pin_memory=self.pin_memory,
)
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for i in range(num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True)
def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
"""
Given the num_scheduled_tokens for each request in the batch, return
datastructures used to activate the current LoRAs.
Returns:
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
where, token_lora_mapping[i] is the LoRA id to use for ith token.
3. lora_requests: Set of relevant LoRA requests.
"""
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
prompt_lora_mapping = tuple(req_lora_mapping)
token_lora_mapping = tuple(
req_lora_mapping.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set(
self.lora_id_to_lora_request.values())
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
@property
def all_greedy(self) -> bool:
return len(self.random_reqs) == 0
@property
def all_random(self) -> bool:
return len(self.greedy_reqs) == 0
@property
def no_top_p(self) -> bool:
return len(self.top_p_reqs) == 0
@property
def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0
@property
def no_penalties(self) -> bool:
return (len(self.presence_penalties_reqs) == 0
and len(self.frequency_penalties_reqs) == 0
and len(self.repetition_penalties_reqs) == 0)
@property
def max_num_logprobs(self) -> Optional[int]:
return max(self.num_logprobs.values()) if self.num_logprobs else None
@property
def no_prompt_logprob(self) -> bool:
return not self.num_prompt_logprobs
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0
],
nopython=True,
cache=True,
)
def prepare_inputs(
idx_mapping: np.ndarray, # batch_idx -> req_idx
token_ids: np.ndarray, # [N, max_model_len]
num_computed_tokens: np.ndarray, # [N]
num_scheduled_tokens: np.ndarray, # [B]
input_ids: np.ndarray, # [num_input_tokens]
query_start_loc: np.ndarray, # [B + 1]
seq_lens: np.ndarray, # [B]
positions: np.ndarray, # [num_input_tokens]
) -> None:
num_reqs = num_scheduled_tokens.shape[0]
query_start_loc[0] = 0
cu_num_tokens = 0
for i in range(num_reqs):
req_idx = idx_mapping[i]
query_len = num_scheduled_tokens[i]
start = num_computed_tokens[req_idx]
end = start + query_len
seq_lens[i] = end
start_idx = cu_num_tokens
end_idx = start_idx + query_len
input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
cu_num_tokens = end_idx
query_start_loc[i + 1] = cu_num_tokens
# Pad the inputs for CUDA graphs.
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention require that
query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
# Fill unused with 0 for full cuda graph mode.
seq_lens[num_reqs:].fill(0)
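
To make the data layout concrete, here is a small NumPy-only walk-through of what the numba-jitted prepare_inputs above produces for a toy two-request batch (the CUDA-graph padding at the end is omitted). All names and numbers are illustrative.

import numpy as np

# Persistent per-request state: token_ids is [max_num_cached_reqs, max_model_len].
token_ids = np.zeros((4, 16), dtype=np.int32)
token_ids[2, :6] = [11, 12, 13, 14, 15, 16]   # request state index 2
token_ids[0, :3] = [21, 22, 23]               # request state index 0

idx_mapping = np.array([2, 0], dtype=np.int32)           # batch slot -> req state idx
num_computed_tokens = np.array([4, 0, 4, 0], dtype=np.int32)
num_scheduled_tokens = np.array([2, 3], dtype=np.int32)  # per batch slot

num_input_tokens = int(num_scheduled_tokens.sum())
input_ids = np.zeros(num_input_tokens, dtype=np.int32)
positions = np.zeros(num_input_tokens, dtype=np.int64)
query_start_loc = np.zeros(len(idx_mapping) + 1, dtype=np.int32)
seq_lens = np.zeros(len(idx_mapping), dtype=np.int32)

cu_num_tokens = 0
for i, req_idx in enumerate(idx_mapping):
    query_len = num_scheduled_tokens[i]
    start = num_computed_tokens[req_idx]
    end = start + query_len
    seq_lens[i] = end
    input_ids[cu_num_tokens:cu_num_tokens + query_len] = token_ids[req_idx, start:end]
    positions[cu_num_tokens:cu_num_tokens + query_len] = np.arange(start, end, dtype=np.int64)
    cu_num_tokens += query_len
    query_start_loc[i + 1] = cu_num_tokens

print(input_ids)        # [15 16 21 22 23]
print(positions)        # [4 5 0 1 2]
print(query_start_loc)  # [0 2 5]
print(seq_lens)         # [6 3]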

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,291 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
import torch
import triton
import triton.language as tl
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
_MAX_SPEC_LEN = 32
@dataclass
class RequestData:
mm_kwargs: list[MultiModalKwargsItem]
mm_positions: list[PlaceholderRange]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
mm_hashes: list[str]
# M-RoPE (only for Qwen2/2.5-VL)
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None
lora_request: Optional[LoRARequest] = None
class RequestState:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
max_num_cached_reqs: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
logitsprocs: Optional[LogitsProcessors] = None,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.max_num_cached_reqs = max_num_cached_reqs
self.device = device
self.pin_memory = pin_memory
self.vocab_size = vocab_size
self.is_spec_decode = is_spec_decode
self.pooling_params = None
self.num_prompt_logprobs: dict[int, int] = {}
self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {}
self.free_indices = list(range(max_num_cached_reqs))
# Request states.
self.req_data: dict[int, RequestData] = {}
# TODO(woosuk): Because the token_ids tensor can be very big, we only
# initialize it on CPU memory.
self.token_ids = np.zeros(
(self.max_num_cached_reqs, self.max_model_len),
dtype=np.int32,
)
self.num_prompt_tokens = np.zeros(self.max_num_cached_reqs,
dtype=np.int32)
self.num_tokens = np.zeros(self.max_num_cached_reqs, dtype=np.int32)
self.num_computed_tokens = np.zeros(self.max_num_cached_reqs,
dtype=np.int32)
# Last sampled token ids.
self.last_sampled_token = torch.zeros(
self.max_num_cached_reqs,
dtype=torch.int32,
device=self.device,
)
self.temperature = np.zeros(self.max_num_cached_reqs, dtype=np.float32)
self.top_p = np.zeros(self.max_num_cached_reqs, dtype=np.float32)
self.top_k = np.zeros(self.max_num_cached_reqs, dtype=np.int32)
self.frequency_penalties = np.zeros(self.max_num_cached_reqs,
dtype=np.float32)
self.presence_penalties = np.zeros(self.max_num_cached_reqs,
dtype=np.float32)
self.repetition_penalties = np.zeros(self.max_num_cached_reqs,
dtype=np.float32)
self.generators: dict[int, torch.Generator] = {}
def add_request(
self,
req_id: str,
prompt_token_ids: list[int],
num_computed_tokens: int,
sampling_params: SamplingParams,
) -> None:
assert len(self.free_indices) > 0, "No free space in GPU worker states"
req_idx = self.free_indices.pop()
self.req_id_to_index[req_id] = req_idx
self.index_to_req_id[req_idx] = req_id
prompt_len = len(prompt_token_ids)
self.num_prompt_tokens[req_idx] = prompt_len
self.num_tokens[req_idx] = prompt_len
self.token_ids[req_idx, :prompt_len] = prompt_token_ids
self.num_computed_tokens[req_idx] = num_computed_tokens
self.temperature[req_idx] = sampling_params.temperature
self.top_p[req_idx] = sampling_params.top_p
if 0 < sampling_params.top_k < self.vocab_size:
top_k = sampling_params.top_k
else:
top_k = self.vocab_size
self.top_k[req_idx] = top_k
self.frequency_penalties[req_idx] = sampling_params.frequency_penalty
self.presence_penalties[req_idx] = sampling_params.presence_penalty
self.repetition_penalties[req_idx] = sampling_params.repetition_penalty
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
self.generators[req_idx] = generator
@property
def num_cached_reqs(self) -> int:
return len(self.req_id_to_index)
def append_token_ids(
self,
req_idx: int,
token_ids: Union[list[int], np.ndarray],
) -> None:
start_idx = self.num_tokens[req_idx]
end_idx = start_idx + len(token_ids)
self.token_ids[req_idx, start_idx:end_idx] = token_ids
self.num_tokens[req_idx] = end_idx
def remove_request(self, req_id: str) -> None:
req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None:
# Request not found.
return
self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx)
def make_sampling_metadata(
self,
idx_mapping: np.ndarray,
) -> SamplingMetadata:
temperature = self.temperature[idx_mapping]
all_greedy = np.all(temperature == 0.0)
all_random = np.all(temperature != 0.0)
temperature = self._copy_np_to_gpu(temperature)
top_p = self.top_p[idx_mapping]
no_top_p = np.all(top_p == 1.0)
top_p = self._copy_np_to_gpu(top_p) if not no_top_p else None
top_k = self.top_k[idx_mapping]
no_top_k = np.all(top_k == self.vocab_size)
top_k = self._copy_np_to_gpu(top_k) if not no_top_k else None
frequency_penalties = self.frequency_penalties[idx_mapping]
presence_penalties = self.presence_penalties[idx_mapping]
repetition_penalties = self.repetition_penalties[idx_mapping]
no_penalties = (np.all(frequency_penalties == 0.0)
and np.all(presence_penalties == 0.0)
and np.all(repetition_penalties == 1.0))
if no_penalties:
frequency_penalties = None
presence_penalties = None
repetition_penalties = None
else:
frequency_penalties = self._copy_np_to_gpu(frequency_penalties)
presence_penalties = self._copy_np_to_gpu(presence_penalties)
repetition_penalties = self._copy_np_to_gpu(repetition_penalties)
if self.generators:
generators = {
req_idx: self.generators[req_idx]
for req_idx in idx_mapping if req_idx in self.generators
}
else:
generators = {}
return SamplingMetadata(
temperature=temperature,
all_greedy=all_greedy,
all_random=all_random,
top_p=top_p,
top_k=top_k,
frequency_penalties=frequency_penalties,
presence_penalties=presence_penalties,
repetition_penalties=repetition_penalties,
no_penalties=no_penalties,
generators=generators,
token_ids=None,
num_tokens=None,
num_prompt_tokens=None,
max_num_logprobs=None,
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=None,
)
def _copy_np_to_gpu(self, src: np.ndarray) -> torch.Tensor:
cpu_tensor = torch.from_numpy(src)
return cpu_tensor.to(device=self.device, non_blocking=True)
def make_spec_decode_metadata(
self,
query_start_loc: torch.Tensor,
cu_num_draft_tokens: torch.Tensor,
cu_num_draft_tokens_np: np.ndarray,
input_ids: torch.Tensor,
) -> SpecDecodeMetadata:
batch_size = query_start_loc.shape[0] - 1
total_num_draft_tokens = cu_num_draft_tokens_np[batch_size - 1]
logits_indices = torch.empty(total_num_draft_tokens + batch_size,
dtype=torch.int32,
device=self.device)
target_logits_indices = torch.empty(total_num_draft_tokens,
dtype=torch.int32,
device=self.device)
bonus_logits_indices = torch.empty(batch_size,
dtype=torch.int32,
device=self.device)
_prepare_spec_decode_kernel[(batch_size, )](
query_start_loc,
cu_num_draft_tokens,
logits_indices,
target_logits_indices,
bonus_logits_indices,
BLOCK_SIZE=triton.next_power_of_2(_MAX_SPEC_LEN + 1),
)
draft_token_ids = input_ids[logits_indices]
draft_token_ids = draft_token_ids[target_logits_indices + 1]
return SpecDecodeMetadata(
draft_token_ids=draft_token_ids,
num_draft_tokens=cu_num_draft_tokens_np.tolist(),
cu_num_draft_tokens=cu_num_draft_tokens,
target_logits_indices=target_logits_indices,
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
)
@triton.jit
def _prepare_spec_decode_kernel(
query_start_loc, # [B + 1]
cu_num_draft_tokens, # [B]
logits_indices, # [N + B]
target_logits_indices, # [N]
bonus_logits_indices, # [B]
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
if batch_idx == 0:
draft_start_idx = 0
else:
draft_start_idx = tl.load(cu_num_draft_tokens + batch_idx - 1)
draft_end_idx = tl.load(cu_num_draft_tokens + batch_idx)
draft_len = draft_end_idx - draft_start_idx
sample_len = draft_len + 1
q_end_idx = tl.load(query_start_loc + batch_idx + 1)
sample_start_idx = draft_start_idx + batch_idx
sample_end_idx = sample_start_idx + sample_len
offset = tl.arange(0, BLOCK_SIZE)
tl.store(logits_indices + sample_start_idx + offset,
q_end_idx - sample_len + offset,
mask=offset < sample_len)
tl.store(target_logits_indices + draft_start_idx + offset,
sample_start_idx + offset,
mask=offset < draft_len)
tl.store(bonus_logits_indices + batch_idx, sample_end_idx - 1)
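
For reference, the following pure-Python sketch reproduces the index layout that _prepare_spec_decode_kernel computes, for a toy batch where each request schedules exactly its draft tokens plus one bonus position (an assumption for this example; all numbers are illustrative).

import numpy as np

num_draft_tokens = [3, 2]                     # drafts per request
query_start_loc = np.cumsum([0] + [n + 1 for n in num_draft_tokens])  # [0, 4, 7]

logits_indices, target_logits_indices, bonus_logits_indices = [], [], []
for batch_idx, draft_len in enumerate(num_draft_tokens):
    q_end_idx = int(query_start_loc[batch_idx + 1])
    sample_len = draft_len + 1
    sample_start_idx = len(logits_indices)
    # The last (draft_len + 1) token positions of this request need logits.
    logits_indices.extend(range(q_end_idx - sample_len, q_end_idx))
    # The first draft_len of those are verified against the draft tokens...
    target_logits_indices.extend(
        range(sample_start_idx, sample_start_idx + draft_len))
    # ...and the final one samples the bonus token.
    bonus_logits_indices.append(sample_start_idx + sample_len - 1)

print(logits_indices)         # [0, 1, 2, 3, 4, 5, 6]
print(target_logits_indices)  # [0, 1, 2, 4, 5]
print(bonus_logits_indices)   # [3, 6]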

View File

@@ -12,7 +12,6 @@ from vllm.sampling_params import SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState
_SAMPLING_EPS = 1e-5