Compare commits


1 Commit

SHA1:    79acf80471
Message: Fast decode prepare path for prepare_inputs logic
         Signed-off-by: Alexander Matveev <alexm@neuralmagic.com>
Date:    2025-05-08 17:26:00 +00:00
4 changed files with 225 additions and 4 deletions

View File

@@ -10,12 +10,12 @@ prompts = [
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=10)
def main():
# Create an LLM.
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="facebook/opt-125m", disable_cascade_attn=True)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.

View File

@@ -85,6 +85,7 @@ if TYPE_CHECKING:
VLLM_ROCM_MOE_PADDING: bool = True
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_ENABLE_V1_ADVANCE_STEP: bool = False
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False
Q_SCALE_CONSTANT: int = 200
@@ -600,6 +601,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
"VLLM_DISABLE_COMPILE_CACHE":
lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))),
"VLLM_ENABLE_V1_ADVANCE_STEP":
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_ADVANCE_STEP", "0"))),
# If set, vllm will run in development mode, which will enable
# some additional endpoints for developing and debugging,
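Usage sketch for the new flag (illustrative, not taken from the patch): since envs.py resolves VLLM_ENABLE_V1_ADVANCE_STEP lazily through os.getenv, the variable only has to be set before the engine is constructed, e.g. at the top of a script:

import os

# Opt in to the experimental fast decode prepare path (default "0" = off).
os.environ["VLLM_ENABLE_V1_ADVANCE_STEP"] = "1"

from vllm import LLM, SamplingParams

# Mirrors the modified example above: cascade attention is turned off,
# since the fast path does not support it.
llm = LLM(model="facebook/opt-125m", disable_cascade_attn=True)
outputs = llm.generate(
    ["The future of AI is"],
    SamplingParams(temperature=0.8, top_p=0.95, max_tokens=10),
)
print(outputs[0].outputs[0].text)

The same thing can be done from a shell with export VLLM_ENABLE_V1_ADVANCE_STEP=1 before launching the script.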

View File

@@ -3,6 +3,7 @@
import numpy as np
import torch
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
@@ -36,6 +37,9 @@ class BlockTable:
self.block_table_np = self.block_table_cpu.numpy()
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
self.prev_num_reqs = 0
self.is_updated = True
def append_row(
self,
block_ids: list[int],
@@ -48,16 +52,22 @@ class BlockTable:
self.num_blocks_per_row[row_idx] += num_blocks
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
self.is_updated = True
def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0
self.append_row(block_ids, row_idx)
self.is_updated = True
def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src]
self.block_table_np[tgt, :num_blocks] = self.block_table_np[
src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks
self.is_updated = True
def swap_row(self, src: int, tgt: int) -> None:
num_blocks_src = self.num_blocks_per_row[src]
num_blocks_tgt = self.num_blocks_per_row[tgt]
@@ -66,14 +76,28 @@ class BlockTable:
self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
self.is_updated = True
def commit(self, num_reqs: int) -> None:
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
non_blocking=True)
if envs.VLLM_ENABLE_V1_ADVANCE_STEP:
# Incremental copy
if self.prev_num_reqs != num_reqs or self.is_updated:
self.block_table[:num_reqs].copy_(
self.block_table_cpu[:num_reqs], non_blocking=True)
self.prev_num_reqs = num_reqs
self.is_updated = False
else:
# Always copy
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
non_blocking=True)
def clear(self) -> None:
self.block_table.fill_(0)
self.block_table_cpu.fill_(0)
self.is_updated = True
def get_device_tensor(self) -> torch.Tensor:
"""Ruturns the device tensor of the block table."""
return self.block_table
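For context, the commit() change above is a dirty-flag guard: the host-to-device copy is skipped whenever neither the CPU table nor the number of active requests has changed since the previous step, which is the common case for steady-state decode. A standalone sketch of the same pattern (illustrative only; the class and method names below are not from the patch):

import torch

class DirtyFlagTable:
    """Toy version of the is_updated / prev_num_reqs guard in commit()."""

    def __init__(self, rows: int, cols: int, device: str = "cpu"):
        self.cpu = torch.zeros(rows, cols, dtype=torch.int32)
        self.dev = torch.zeros(rows, cols, dtype=torch.int32, device=device)
        self.is_updated = True      # force the first copy
        self.prev_num_reqs = 0

    def write_row(self, row: int, values: list[int]) -> None:
        self.cpu[row, :len(values)] = torch.tensor(values, dtype=torch.int32)
        self.is_updated = True      # any CPU-side mutation marks the table dirty

    def commit(self, num_reqs: int) -> None:
        # Copy only when something changed; otherwise the device copy
        # from the previous step is still valid and can be reused.
        if self.prev_num_reqs != num_reqs or self.is_updated:
            self.dev[:num_reqs].copy_(self.cpu[:num_reqs], non_blocking=True)
            self.prev_num_reqs = num_reqs
            self.is_updated = False

The subtlety is that the flag must be set on every mutation path (append_row, add_row, move_row, swap_row, clear), which is exactly what the hunks above do.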

View File

@@ -10,6 +10,7 @@ import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
@@ -142,6 +143,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
weakref.proxy(self))
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
if envs.VLLM_ENABLE_V1_ADVANCE_STEP:
logger.info("Advance_step is enabled")
if self.cascade_attn_enabled:
logger.warning(
"Disabling cascade attn (since advance_step is on)")
self.cascade_attn_enabled = False
else:
logger.info("Advance_step is disabled")
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
@@ -271,16 +281,51 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device="cpu",
pin_memory=self.pin_memory)
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
self.slot_mapping_gpu = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
self.query_start_loc_gpu = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device=self.device)
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
self.seq_lens_gpu = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device=self.device)
# Cached state for the advance_step fast decode path
self.prev_num_reqs = 0
self.req_indices_gpu = torch.arange(self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.req_indices_block_table_offsets_gpu = (
self.req_indices_gpu * self.max_num_blocks_per_req)
self.num_scheduled_tokens_gpu = torch.ones(self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.cu_num_tokens_gpu = torch.cumsum(self.num_scheduled_tokens_gpu, 0)
self.query_start_loc_gpu[0] = 0
self.query_start_loc_gpu[1:self.max_num_reqs +
1] = self.cu_num_tokens_gpu
self.logits_indices_gpu = self.query_start_loc_gpu[1:] - 1
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
self.prev_attn_metadata = None
self.is_first_advance_decode = True
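The cached tensors above rely on the steady-state decode invariant that every request schedules exactly one token, so query_start_loc and logits_indices never change between steps. A small sketch of that arithmetic (illustrative, with max_num_reqs assumed to be 4):

import torch

max_num_reqs = 4

# One scheduled token per request in a pure decode step.
num_scheduled_tokens = torch.ones(max_num_reqs, dtype=torch.int32)
cu_num_tokens = torch.cumsum(num_scheduled_tokens, 0)      # [1, 2, 3, 4]

query_start_loc = torch.zeros(max_num_reqs + 1, dtype=torch.int32)
query_start_loc[1:] = cu_num_tokens                        # [0, 1, 2, 3, 4]

# The last (and only) token of each request is the one sampled from.
logits_indices = query_start_loc[1:] - 1                   # [0, 1, 2, 3]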
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
@@ -485,6 +530,119 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if batch_changed or batch_reordered:
self.input_batch.refresh_sampling_metadata()
def _advance_decode_step(
self,
scheduler_output,
num_scheduled_tokens,
):
# print(" -- inside advance_decode_step")
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens == num_reqs
# TODO: Add if needed
# Get request indices.
# E.g., num_reqs == 3 -> [0, 1, 2]
# req_indices_gpu = self.req_indices_gpu[:num_reqs]
# Get cu_sums
# cu_num_tokens = self.cu_num_tokens_gpu[:num_reqs]
# Increment positions
positions_gpu = self.positions[:total_num_scheduled_tokens]
positions_gpu[:total_num_scheduled_tokens] += 1
# TODO: Verify MROPE is ok here
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
if self.uses_mrope:
self._calc_mrope_positions(scheduler_output)
# Set next tokens
# (prev iteration tokens are cached in prev_sampled_token_ids tensor)
assert self.prev_sampled_token_ids is not None
self.input_ids[:total_num_scheduled_tokens] = \
self.prev_sampled_token_ids[:,0]
# Calculate the slot mapping
block_table_indices_gpu = (
self.req_indices_block_table_offsets_gpu[:num_reqs] +
positions_gpu // self.block_size)
block_table_gpu = self.input_batch.block_table.get_device_tensor()
# Note: The block table tensor is async copied from CPU to GPU
# (inside the .commit() call) if it was previously modified
block_numbers_gpu = block_table_gpu.flatten()[block_table_indices_gpu]
block_offsets_gpu = positions_gpu % self.block_size
slot_mapping_gpu = self.slot_mapping_gpu[:total_num_scheduled_tokens]
slot_mapping_gpu[:] = (block_numbers_gpu * self.block_size +
block_offsets_gpu)
# Prepare the attention metadata.
# query_start_loc is always the same for all decode iterations
query_start_loc_gpu = self.query_start_loc_gpu[:num_reqs + 1]
if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
non_blocking=True)
# TODO: Add cascade attn support
# Verify cascade attention is disabled
assert not self.cascade_attn_enabled
# TODO: Add support for other attn backends
assert self.prev_attn_metadata is not None
assert isinstance(self.prev_attn_metadata, FlashAttentionMetadata)
attn_metadata = self.prev_attn_metadata
attn_metadata.max_seq_len += 1
attn_metadata.query_start_loc = query_start_loc_gpu
attn_metadata.seq_lens += 1
attn_metadata.slot_mapping = slot_mapping_gpu
# print("attn_metadata.seq_lens: shape = {} data = {}".format(
# attn_metadata.seq_lens.shape, attn_metadata.seq_lens))
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = self.logits_indices_gpu[:num_reqs]
spec_decode_metadata = None
else:
# TODO: Check if spec_decode can be enabled here
raise NotImplementedError("advance_step has no support for spec_decode yet")
# # Get the number of draft tokens for each request.
# # Iterate over the dictionary rather than all requests since
# # not all requests have draft tokens.
# num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
# for req_id, draft_token_ids in (
# scheduler_output.scheduled_spec_decode_tokens.items()):
# req_idx = self.input_batch.req_id_to_index[req_id]
# num_draft_tokens[req_idx] = len(draft_token_ids)
# spec_decode_metadata = self._calc_spec_decode_metadata(
# num_draft_tokens, cu_num_tokens)
# logits_indices = spec_decode_metadata.logits_indices
# Hot-Swap lora model
if self.lora_config:
# TODO: Check if this works
raise NotImplementedError("advance_step has no LoRA support yet")
self.set_active_loras(self.input_batch, num_scheduled_tokens)
return attn_metadata, logits_indices, spec_decode_metadata
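The slot-mapping computation above is plain indexing arithmetic on GPU-resident tensors: for the token at position pos of request r, the slot is block_table[r, pos // block_size] * block_size + pos % block_size, evaluated on the flattened table via per-request row offsets. A small numeric sketch (illustrative values, not from the patch):

import torch

block_size = 16
max_num_blocks_per_req = 4
num_reqs = 2

# (num_reqs, max_num_blocks_per_req) block table, flattened as on the device.
block_table = torch.tensor([[7, 9, 0, 0],
                            [3, 0, 0, 0]], dtype=torch.int32)
positions = torch.tensor([17, 5], dtype=torch.int64)         # next-token positions

req_offsets = torch.arange(num_reqs, dtype=torch.int64) * max_num_blocks_per_req
block_table_indices = req_offsets + positions // block_size  # [1, 4]
block_numbers = block_table.flatten()[block_table_indices]   # [9, 3]
slot_mapping = block_numbers * block_size + positions % block_size
# -> [9 * 16 + 1, 3 * 16 + 5] = [145, 53]

Because every operand already lives on the GPU (the incremented positions, the committed block table, and the cached row offsets), this path needs no CPU-side numpy work and no extra host-to-device transfer.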
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
@@ -505,6 +663,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
# Determine if advance step can be used
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
is_flash_attn = self.prev_attn_metadata is not None and isinstance(
self.prev_attn_metadata, FlashAttentionMetadata)
is_advance_decode = (envs.VLLM_ENABLE_V1_ADVANCE_STEP
and self.prev_num_reqs == num_reqs
and max_num_scheduled_tokens == 1
and not use_spec_decode
and not self.cascade_attn_enabled
and is_flash_attn)
if is_advance_decode:
if self.is_first_advance_decode:
# The first time advance_step can be used,
# we run the usual prepare, so that positions tensor
# is initialized
self.is_first_advance_decode = False
else:
# This is the fast-path advance_step
# (all tensors are on the GPU and are updated on the GPU)
(attn_metadata, logits_indices,
spec_decode_metadata) = self._advance_decode_step(
scheduler_output, num_scheduled_tokens)
return attn_metadata, logits_indices, spec_decode_metadata
else:
self.is_first_advance_decode = True
self.prev_num_reqs = num_reqs
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(self.arange_np[:num_reqs],
@@ -523,6 +713,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
# Get positions.
positions_np = self.positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
@@ -599,6 +790,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_query_len=max_num_scheduled_tokens,
common_prefix_len=common_prefix_len,
)
self.prev_attn_metadata = attn_metadata
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -1177,6 +1369,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
self.prev_sampled_token_ids = sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.