# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import gc
import itertools
import time
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
from tqdm import tqdm
from typing_extensions import TypeAlias

import vllm.envs as envs
from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (
    CompilationLevel,
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
    update_config,
)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.distributed.parallel_state import (
    get_pp_group,
    get_tp_group,
    graph_capture,
    is_global_first_rank,
    prepare_communication_buffer_for_model,
)
from vllm.forward_context import BatchDescriptor, DPMetadata, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.interfaces import (
    SupportsMultiModal,
    is_mixture_of_experts,
    supports_eagle3,
    supports_mrope,
    supports_multimodal_pruning,
    supports_transcription,
)
from vllm.model_executor.models.interfaces_base import (
    VllmModelForPooling,
    is_pooling_model,
    is_text_generation_model,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    BatchedTensorInputs,
    MultiModalKwargsItem,
    PlaceholderRange,
)
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    DeviceMemoryProfiler,
    GiB_bytes,
    cdiv,
    check_use_alibi,
    get_dtype_size,
    is_pin_memory_available,
    length_from_prompt_token_ids_or_embeds,
    round_up,
    supports_dynamo,
)
from vllm.utils.jsontree import json_map_leaves
from vllm.v1.attention.backends.flash_attn import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    create_fast_prefill_custom_backend,
    reorder_batch_to_split_decodes_and_prefills,
    split_attn_metadata,
)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (
    AttentionSpec,
    ChunkedLocalAttentionSpec,
    CrossAttentionSpec,
    EncoderOnlyAttentionSpec,
    FullAttentionSpec,
    KVCacheConfig,
    KVCacheGroupSpec,
    KVCacheSpec,
    MambaSpec,
    MLAAttentionSpec,
    SlidingWindowSpec,
    UniformTypeKVCacheSpecs,
)
from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    AsyncModelRunnerOutput,
    DraftTokenIds,
    LogprobsLists,
    LogprobsTensors,
    ModelRunnerOutput,
    PoolerOutput,
    SamplerOutput,
)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.ubatch_splitting import check_ubatch_thresholds, ubatch_split
from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices
from vllm.v1.worker.utils import is_residual_scattered_for_sp

from .utils import (
    AttentionGroup,
    MultiModalBudget,
    add_kv_sharing_layers_to_kv_cache_groups,
    bind_kv_cache,
    gather_mm_placeholders,
    sanity_check_mm_encoder_outputs,
    scatter_mm_placeholders,
)

if TYPE_CHECKING:
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.core.sched.output import SchedulerOutput

logger = init_logger(__name__)

AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata]
# list when ubatching is enabled
PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict], AttnMetadataDict]

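# Illustrative note (not part of the upstream vLLM source): without
# micro-batching, the per-layer metadata is a single mapping such as
#   {"model.layers.0.self_attn": metadata_0, ...}
# while with ubatching enabled it becomes one such mapping per micro-batch:
#   [{"model.layers.0.self_attn": metadata_ub0, ...},
#    {"model.layers.0.self_attn": metadata_ub1, ...}]
# (the layer name here is a made-up example).
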
# Wrapper for ModelRunnerOutput to support overlapped execution.
class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
    def __init__(
        self,
        model_runner_output: ModelRunnerOutput,
        sampled_token_ids: torch.Tensor,
        invalid_req_indices: list[int],
        async_output_copy_stream: torch.cuda.Stream,
    ):
        self._model_runner_output = model_runner_output
        self._invalid_req_indices = invalid_req_indices

        # Event on the copy stream so we can synchronize the non-blocking copy.
        self._async_copy_ready_event = torch.cuda.Event()

        # Keep a reference to the device tensor to avoid it being
        # deallocated until we finish copying it to the host.
        self._sampled_token_ids = sampled_token_ids

        # Initiate the copy on a separate stream, but do not synchronize it.
        default_stream = torch.cuda.current_stream()
        with torch.cuda.stream(async_output_copy_stream):
            async_output_copy_stream.wait_stream(default_stream)
            self._sampled_token_ids_cpu = self._sampled_token_ids.to(
                "cpu", non_blocking=True
            )
            self._async_copy_ready_event.record()

    def get_output(self) -> ModelRunnerOutput:
        """Copy the device tensors to the host and return a ModelRunnerOutput.

        This function blocks until the copy is finished.
        """
        self._async_copy_ready_event.synchronize()

        # Release the device tensor once the copy has completed
        del self._sampled_token_ids

        valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
        for i in self._invalid_req_indices:
            valid_sampled_token_ids[i].clear()

        output = self._model_runner_output
        output.sampled_token_ids = valid_sampled_token_ids
        return output

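# Illustrative sketch (not part of the upstream source): the class above uses
# the standard CUDA side-stream copy pattern:
#
#   default_stream = torch.cuda.current_stream()
#   with torch.cuda.stream(copy_stream):
#       copy_stream.wait_stream(default_stream)      # order after sampling work
#       host_copy = device_tensor.to("cpu", non_blocking=True)
#       ready_event.record()
#   ...
#   ready_event.synchronize()                        # later, in get_output()
#
# Keeping a reference to `device_tensor` until the event fires is what prevents
# the caching allocator from reusing that memory while the copy is in flight.
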
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config

        from vllm.model_executor.models.utils import set_cpu_offload_max_bytes

        set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3))

        from vllm.model_executor.layers.batch_invariant import init_batch_invariance

        init_batch_invariance()

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        if cache_config.cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        self.is_pooling_model = model_config.runner_type == "pooling"
        self.enable_prompt_embeds = model_config.enable_prompt_embeds
        self.is_multimodal_raw_input_only_model = (
            model_config.is_multimodal_raw_input_only_model
        )
        # This will be overridden in load_model()
        self.is_multimodal_pruning_enabled = False
        self.max_model_len = model_config.max_model_len
        self.dcp_world_size = self.parallel_config.decode_context_parallel_size
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Broadcast PP output for external_launcher (torchrun)
        # to make sure we are synced across pp ranks
        # TODO: Support overlapping micro-batches
        # https://github.com/vllm-project/vllm/issues/18019
        self.broadcast_pp_output = (
            self.parallel_config.distributed_executor_backend == "external_launcher"
            and len(get_pp_group().ranks) > 0
        )

        # Model-related.
        self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
        self.hidden_size = model_config.get_hidden_size()
        self.attention_chunk_size = model_config.attention_chunk_size
        # Only relevant for models using ALiBi (e.g., MPT)
        self.use_alibi = check_use_alibi(model_config)

        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            model_config
        )

        if self.model_config.is_encoder_decoder:
            # Maximum length of the encoder input, only for encoder-decoder
            # models.
            self.max_encoder_len = scheduler_config.max_num_encoder_input_tokens
        else:
            self.max_encoder_len = 0

        # Sampler
        self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)

        self.eplb_state: Optional[EplbState] = None
        """
        State of the expert parallelism load balancer.

        Will be lazily initialized when the model is loaded.
        """

        # Lazy initializations
        # self.model: nn.Module  # Set after load_model
        # Initialize in initialize_kv_cache
        self.kv_caches: list[torch.Tensor] = []
        # indexes: [kv_cache_group_id][attn_group]
        self.attn_groups: list[list[AttentionGroup]] = []
        # self.kv_cache_config: KVCacheConfig

        # mm_hash -> encoder_output
        self.encoder_cache: dict[str, torch.Tensor] = {}

        self.use_aux_hidden_state_outputs = False
        # Set up speculative decoding.
        # NOTE(Jiayi): currently we put the entire draft model on
        # the last PP rank. This is not ideal if there are many
        # layers in the draft model.
        if self.speculative_config and get_pp_group().is_last_rank:
            if self.speculative_config.method == "ngram":
                self.drafter = NgramProposer(self.vllm_config)
            elif self.speculative_config.use_eagle():
                self.drafter = EagleProposer(self.vllm_config, self.device, self)  # type: ignore
                if self.speculative_config.method == "eagle3":
                    self.use_aux_hidden_state_outputs = True
            elif self.speculative_config.method == "medusa":
                self.drafter = MedusaProposer(
                    vllm_config=self.vllm_config, device=self.device
                )  # type: ignore
            else:
                raise ValueError(
                    "Unknown speculative decoding method: "
                    f"{self.speculative_config.method}"
                )
            self.rejection_sampler = RejectionSampler()

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}
        self.comm_stream = torch.cuda.Stream()

        # Input Batch
        # NOTE(Chen): Ideally, we should initialize the input batch inside
        # `initialize_kv_cache` based on the kv cache config. However, as in
        # https://github.com/vllm-project/vllm/pull/18298, due to some unknown
        # reasons, we have to initialize the input batch before `load_model`,
        # quantization + weight offloading will fail otherwise. As a temporary
        # solution, we initialize the input batch here, and re-initialize it
        # in `initialize_kv_cache` if the block_sizes here is different from
        # the block_sizes in the kv cache config.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            # We need to use the encoder length for encoder-decoder
            # models because of the KV cache for cross-attention.
            max_model_len=max(self.max_model_len, self.max_encoder_len),
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=[self.cache_config.block_size],
            is_spec_decode=bool(self.vllm_config.speculative_config),
            logitsprocs=build_logitsprocs(
                self.vllm_config,
                self.device,
                self.pin_memory,
                self.is_pooling_model,
                self.vllm_config.model_config.logits_processors,
            ),
            is_pooling_model=self.is_pooling_model,
        )

        self.use_async_scheduling = self.scheduler_config.async_scheduling
        self.async_output_copy_stream = (
            torch.cuda.Stream() if self.use_async_scheduling else None
        )

        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
        # The convention is different:
        # self.cudagraph_batch_sizes sorts in ascending order, while
        # the batch sizes in the config are in descending order.
        if (
            self.compilation_config.cudagraph_capture_sizes
            and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        ):
            self.cudagraph_batch_sizes = list(
                reversed(self.compilation_config.cudagraph_capture_sizes)
            )

        # Cache the device properties.
        self._init_device_properties()

        # Persistent buffers for CUDA graphs.
        self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
        self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64)
        self.query_start_loc = self._make_buffer(
            self.max_num_reqs + 1, dtype=torch.int32
        )
        self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
        # Because inputs_embeds may be bfloat16 and we don't need a numpy
        # version of this tensor, avoid a RuntimeError by not creating a
        # numpy buffer.
        self.inputs_embeds = self._make_buffer(
            self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False
        )
        self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
        self.discard_request_indices = self._make_buffer(
            self.max_num_reqs, dtype=torch.int64
        )
        self.num_discarded_requests = 0

        self.num_decode_draft_tokens = self._make_buffer(
            self.max_num_reqs, dtype=torch.int32
        )
        self.num_accepted_tokens = self._make_buffer(
            self.max_num_reqs, dtype=torch.int64
        )

        # Only relevant for multimodal models
        if self.supports_mm_inputs:
            self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool)

        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = self._make_buffer(
                (3, self.max_num_tokens + 1), dtype=torch.int64
            )

        # CUDA event to synchronize use of reused CPU tensors between steps
        # when async scheduling is enabled.
        self.prepare_inputs_event: Optional[torch.cuda.Event] = None
        if self.use_async_scheduling:
            self.prepare_inputs_event = torch.cuda.Event()
            # Start in a completed state.
            self.prepare_inputs_event.record(torch.cuda.default_stream())

        # None in the first PP rank. The rest are set after load_model.
        self.intermediate_tensors: Optional[IntermediateTensors] = None

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        # Keep in int64 to avoid overflow with long context
        self.arange_np = np.arange(
            max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens),
            dtype=np.int64,
        )

        # Layer pairings for cross-layer KV sharing.
        # If an Attention layer `layer_name` is in the keys of this dict, it
        # means this layer will perform attention using the keys and values
        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
        self.shared_kv_cache_layers: dict[str, str] = {}
        self.kv_sharing_fast_prefill_eligible_layers: set[str] = set()

        self.kv_sharing_fast_prefill_logits_indices = None
        if self.cache_config.kv_sharing_fast_prefill:
            self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
                self.max_num_tokens, dtype=torch.int32, device=self.device
            )

        self.uniform_decode_query_len = (
            1
            if not self.speculative_config
            else 1 + self.speculative_config.num_speculative_tokens
        )

        # Cudagraph dispatcher for runtime cudagraph dispatching.
        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

        self.mm_budget = (
            MultiModalBudget(
                self.model_config,
                self.scheduler_config,
                self.mm_registry,
            )
            if self.supports_mm_inputs
            else None
        )

        self.reorder_batch_threshold: Optional[int] = None

        # Attention layers that are only in the KVCacheConfig of the runner
        # (e.g., KV sharing, encoder-only attention), but not in the
        # KVCacheConfig of the scheduler.
        self.runner_only_attn_layers: set[str] = set()

        # Cached outputs.
        self._draft_token_ids: Optional[Union[list[list[int]], torch.Tensor]] = None
        self.transfer_event = torch.cuda.Event()
        self.sampled_token_ids_pinned_cpu = torch.empty(
            (self.max_model_len, 1),
            dtype=torch.int64,
            device="cpu",
            pin_memory=self.pin_memory,
        )

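    # Illustrative note (not part of the upstream source): with speculative
    # decoding, a "uniform decode" step runs 1 + num_speculative_tokens query
    # tokens per request, e.g. num_speculative_tokens=2 gives
    # uniform_decode_query_len == 3; _prepare_inputs later compares the max
    # scheduled query length against this value to detect pure-decode batches
    # (e.g. for micro-batch splitting).
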
    def _get_positions(self, num_tokens: Any):
        if isinstance(num_tokens, int):
            if self.uses_mrope:
                return self.mrope_positions.gpu[:, :num_tokens]
            return self.positions.gpu[:num_tokens]
        else:
            if self.uses_mrope:
                return self.mrope_positions.gpu[:, num_tokens]
            return self.positions.gpu[num_tokens]

    def _make_buffer(
        self, *size: Union[int, torch.SymInt], dtype: torch.dtype, numpy: bool = True
    ) -> CpuGpuBuffer:
        return CpuGpuBuffer(
            *size,
            dtype=dtype,
            device=self.device,
            pin_memory=self.pin_memory,
            with_numpy=numpy,
        )

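    # Illustrative note (not part of the upstream source): `_get_positions(8)`
    # returns the first 8 position ids, while passing an index tensor or slice
    # gathers arbitrary positions; for M-RoPE models the returned view is
    # shaped (3, N) instead of (N,) because M-RoPE position ids are 3D.
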
    def _init_model_kwargs(self, num_tokens: int):
        model_kwargs = dict[str, Any]()

        if not self.is_pooling_model:
            return model_kwargs

        num_reqs = self.input_batch.num_reqs
        pooling_params = self.input_batch.get_pooling_params()

        token_type_id_requests = dict[int, Any]()
        for i, param in enumerate(pooling_params):
            if (
                param.extra_kwargs is not None
                and (token_types := param.extra_kwargs.get("compressed_token_type_ids"))
                is not None
            ):
                token_type_id_requests[i] = token_types

        if len(token_type_id_requests) == 0:
            return model_kwargs

        seq_lens = self.seq_lens.gpu[:num_reqs]
        token_type_ids = []

        for i in range(num_reqs):
            pos = token_type_id_requests.get(i, seq_lens[i])
            ids = (torch.arange(seq_lens[i]) >= pos).int()
            token_type_ids.append(ids)

        model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to(
            device=self.device
        )
        return model_kwargs

    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
        """
        Update the order of requests in the batch based on the attention
        backend's needs. For example, some attention backends (namely MLA) may
        want to separate requests based on if the attention computation will be
        compute-bound or memory-bound.

        Args:
            scheduler_output: The scheduler output.
        """
        # Attention-free models have zero kv_cache_groups, however models
        # like Mamba are also attention free but use the kv_cache for
        # keeping their internal state. This is why we check the number
        # of kv_cache groups instead of solely checking
        # for self.model_config.is_attention_free.
        if len(self.kv_cache_config.kv_cache_groups) == 0:
            return

        if self.reorder_batch_threshold is not None:
            # NOTE(lucas): currently no backend supports the custom masking
            # required for DCP with q_len > 1, so we assert here. Remove this
            # assert once custom mask support is added to FA3.
            if self.dcp_world_size > 1:
                assert self.reorder_batch_threshold == 1, (
                    "DCP does not support reorder_batch_threshold > 1 yet."
                )
            reorder_batch_to_split_decodes_and_prefills(
                self.input_batch,
                scheduler_output,
                decode_threshold=self.reorder_batch_threshold,
            )

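    # Illustrative example (not part of the upstream source): with
    # reorder_batch_threshold == 1 and scheduled query lengths
    # [prefill=512, decode=1, decode=1], the helper above groups the two
    # decode requests together ahead of the prefill so that decode-style and
    # prefill-style attention kernels each see a contiguous span of the batch.
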
    # Note: used for model runner override.
    def _init_device_properties(self) -> None:
        """Initialize attributes from torch.cuda.get_device_properties"""
        self.device_properties = torch.cuda.get_device_properties(self.device)
        self.num_sms = self.device_properties.multi_processor_count

    # Note: used for model runner override.
    def _sync_device(self) -> None:
        torch.cuda.synchronize()

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input GPU tensors for the model.

        The SamplingMetadata is updated and copied to the GPU if there is a
        new/resumed/paused/finished request in the batch.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        for req_id in scheduler_output.finished_req_ids:
            self.input_batch.remove_request(req_id)

        # Free the cached encoder outputs.
        for mm_hash in scheduler_output.free_encoder_mm_hashes:
            self.encoder_cache.pop(mm_hash, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            self.input_batch.remove_request(req_id)

        reqs_to_add: list[CachedRequestState] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            pooling_params = new_req_data.pooling_params

            if (
                sampling_params
                and sampling_params.sampling_type == SamplingType.RANDOM_SEED
            ):
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            if self.is_pooling_model:
                assert pooling_params is not None
                task = pooling_params.task
                assert task is not None, "You did not set `task` in the API"

                model = cast(VllmModelForPooling, self.get_model())
                to_update = model.pooler.get_pooling_updates(task)
                to_update.apply(pooling_params)

            req_state = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt_embeds=new_req_data.prompt_embeds,
                mm_features=new_req_data.mm_features,
                sampling_params=sampling_params,
                pooling_params=pooling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )
            self.requests[req_id] = req_state

            # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
            if self.uses_mrope:
                self._init_mrope_positions(req_state)

            reqs_to_add.append(req_state)

        # Update the states of the running/resumed requests.
        is_last_rank = get_pp_group().is_last_rank
        req_data = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(req_data.req_ids):
            req_state = self.requests[req_id]
            num_computed_tokens = req_data.num_computed_tokens[i]
            new_block_ids = req_data.new_block_ids[i]
            resumed_from_preemption = req_data.resumed_from_preemption[i]
            num_output_tokens = req_data.num_output_tokens[i]

            # Update the cached states.

            req_state.num_computed_tokens = num_computed_tokens

            if not is_last_rank:
                # When using PP, the scheduler sends the sampled tokens back,
                # because there's no direct communication between the first-
                # stage worker and the last-stage worker.
                new_token_ids = req_data.new_token_ids[i]
                # Add the sampled token(s) from the previous step (if any).
                # This doesn't include "unverified" tokens like spec tokens.
                num_new_tokens = (
                    num_computed_tokens + len(new_token_ids) - req_state.num_tokens
                )
                if num_new_tokens == 1:
                    # Avoid slicing list in most common case.
                    req_state.output_token_ids.append(new_token_ids[-1])
                elif num_new_tokens > 0:
                    req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:])
            elif num_output_tokens < len(req_state.output_token_ids):
                # Some output tokens were discarded due to a sync-KV-load
                # failure. Align the cached state.
                del req_state.output_token_ids[num_output_tokens:]

                req_index = self.input_batch.req_id_to_index.get(req_id)
                if req_index is not None:
                    old_end_idx = self.input_batch.num_tokens_no_spec[req_index]
                    end_idx = (
                        self.input_batch.num_prompt_tokens[req_index]
                        + num_output_tokens
                    )
                    self.input_batch.num_tokens[req_index] = end_idx
                    self.input_batch.num_tokens_no_spec[req_index] = end_idx
                    self.input_batch.is_token_ids[req_index, end_idx:old_end_idx] = (
                        False
                    )

            # Update the block IDs.
            if not resumed_from_preemption:
                if new_block_ids is not None:
                    # Append the new blocks to the existing block IDs.
                    for block_ids, new_ids in zip(req_state.block_ids, new_block_ids):
                        block_ids.extend(new_ids)
            else:
                assert new_block_ids is not None
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                reqs_to_add.append(req_state)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens
            if new_block_ids is not None:
                self.input_batch.block_table.append_row(new_block_ids, req_index)

            # For the last rank, we don't need to update the token_ids_cpu
            # because the sampled tokens are already cached.
            if not is_last_rank:
                # Add new_token_ids to token_ids_cpu.
                start_token_index = num_computed_tokens
                end_token_index = num_computed_tokens + len(new_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index, start_token_index:end_token_index
                ] = new_token_ids
                self.input_batch.num_tokens_no_spec[req_index] = end_token_index
                self.input_batch.num_tokens[req_index] = end_token_index

            # Add spec_token_ids to token_ids_cpu.
            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                req_id, ()
            )
            if spec_token_ids:
                num_spec_tokens = len(spec_token_ids)
                start_index = self.input_batch.num_tokens_no_spec[req_index]
                end_token_index = start_index + num_spec_tokens
                self.input_batch.token_ids_cpu[
                    req_index, start_index:end_token_index
                ] = spec_token_ids
                # NOTE(woosuk): `num_tokens` here may include spec tokens.
                self.input_batch.num_tokens[req_index] += num_spec_tokens

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        for request in reqs_to_add:
            self.input_batch.add_request(request)

        # Condense the batched states if there are gaps left by removed requests
        self.input_batch.condense()
        # Allow attention backend to reorder the batch, potentially
        self._may_reorder_batch(scheduler_output)
        # Refresh batch metadata with any pending updates.
        self.input_batch.refresh_metadata()

    def _update_states_after_model_execute(
        self, output_token_ids: torch.Tensor
    ) -> None:
        """Update the cached states after model execution.

        This is used for MTP/EAGLE with hybrid models: in linear attention,
        only the last token's state is kept. In MTP/EAGLE, the states for
        draft tokens are kept until we decide how many tokens are accepted
        for each sequence, and a shift is done during the next iteration
        based on the number of accepted tokens.
        """
        if not self.model_config.is_hybrid or not self.speculative_config:
            return

        # Find the number of accepted tokens for each sequence.
        num_accepted_tokens = (
            (
                torch.cat(
                    [
                        output_token_ids,
                        torch.full(
                            (output_token_ids.size(0), 1),
                            -1,
                            device=output_token_ids.device,
                        ),
                    ],
                    dim=1,
                )
                == -1
            )
            .int()
            .argmax(-1)
            .cpu()
            .numpy()
        )
        for i, num_tokens in enumerate(num_accepted_tokens):
            self.input_batch.num_accepted_tokens_cpu[i] = num_tokens

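    # Illustrative example (not part of the upstream source): rejected draft
    # positions are marked with -1 in `output_token_ids`, so for a row like
    # [17, 23, -1, -1] the code above appends a sentinel -1 column
    # ([17, 23, -1, -1, -1]) and takes the argmax of the `== -1` mask, which
    # yields 2 accepted tokens; a fully-accepted row gets its count from the
    # appended sentinel column instead.
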
    def _init_mrope_positions(self, req_state: CachedRequestState):
        image_grid_thw = []
        video_grid_thw = []
        second_per_grid_ts = []
        audio_feature_lengths = []
        use_audio_in_video = False
        for mm_feature in req_state.mm_features:
            mm_item = mm_feature.data
            if mm_item is None:
                continue
            mm_input = mm_item.get_data()
            if (t := mm_input.get("image_grid_thw")) is not None:
                image_grid_thw.append(t.tolist())
            if (t := mm_input.get("video_grid_thw")) is not None:
                video_grid_thw.append(t.tolist())
            if (t := mm_input.get("second_per_grid_ts")) is not None:
                second_per_grid_ts.append(t)
            if (t := mm_input.get("audio_feature_lengths")) is not None:
                audio_feature_lengths.append(t)
            if mm_input.get("use_audio_in_video") is True:
                use_audio_in_video = True

        if supports_mrope(self.model):
            req_state.mrope_positions, req_state.mrope_position_delta = (
                self.model.get_mrope_input_positions(
                    req_state.prompt_token_ids,
                    hf_config=self.model_config.hf_config,
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    second_per_grid_ts=second_per_grid_ts,
                    audio_feature_lengths=audio_feature_lengths,
                    use_audio_in_video=use_audio_in_video,
                )
            )
        else:
            req_state.mrope_positions, req_state.mrope_position_delta = (
                MRotaryEmbedding.get_input_positions_tensor(
                    req_state.prompt_token_ids,
                    hf_config=self.model_config.hf_config,
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    second_per_grid_ts=second_per_grid_ts,
                    audio_feature_lengths=audio_feature_lengths,
                    use_audio_in_video=use_audio_in_video,
                )
            )

    def _extract_mm_kwargs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> BatchedTensorInputs:
        if not scheduler_output or not self.is_multimodal_raw_input_only_model:
            return {}

        mm_kwargs = list[MultiModalKwargsItem]()
        for req in scheduler_output.scheduled_new_reqs:
            for feature in req.mm_features:
                if feature.data is not None:
                    mm_kwargs.append(feature.data)

        # Input all modalities at once
        model = cast(SupportsMultiModal, self.model)
        mm_kwargs_combined: BatchedTensorInputs = {}
        for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
            mm_kwargs,
            device=self.device,
            pin_memory=self.pin_memory,
            merge_by_field_config=model.merge_by_field_config,
        ):
            mm_kwargs_combined.update(mm_kwargs_group)

        return mm_kwargs_combined

    def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
        if not self.is_multimodal_raw_input_only_model:
            return {}

        mm_budget = self.mm_budget
        assert mm_budget is not None

        dummy_modality = mm_budget.get_modality_with_max_tokens()
        return self._get_mm_dummy_batch(dummy_modality, num_seqs)

    def _get_cumsum_and_arange(
        self,
        num_tokens: np.ndarray,
        cumsum_dtype: Optional[np.dtype] = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Get the cumulative sum and batched arange of the given array.
        # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
        # Equivalent to but faster than:
        # np.concatenate([np.arange(n) for n in num_tokens])
        """
        # Step 1. [2, 5, 3] -> [2, 7, 10]
        cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
        total_num_tokens = cu_num_tokens[-1]
        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
        cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        arange = self.arange_np[:total_num_tokens] - cumsums_offsets

        return cu_num_tokens, arange

    def _prepare_input_ids(
        self, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray
    ) -> None:
        """Prepare the input IDs for the current batch.

        Carefully handles the `prev_sampled_token_ids` which can be cached
        from the previous engine iteration, in which case those tokens on the
        GPU need to be copied into the corresponding slots of input_ids."""

        if self.input_batch.prev_sampled_token_ids is None:
            # Normal scheduling case
            self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
            if self.enable_prompt_embeds:
                self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
                self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
            return

        # Async scheduling case, where some decode requests from the previous
        # iteration won't have entries in input_ids_cpu and need to be copied
        # on the GPU from prev_sampled_token_ids.
        prev_req_id_to_index = self.input_batch.prev_req_id_to_index
        assert prev_req_id_to_index is not None
        flattened_indices = []
        prev_common_req_indices = []
        indices_match = True
        max_flattened_index = -1
        for req_id, cur_index in self.input_batch.req_id_to_index.items():
            if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
                prev_common_req_indices.append(prev_index)
                # We need to compute the flattened input_ids index of the
                # last token in each common request.
                flattened_index = cu_num_tokens[cur_index].item() - 1
                flattened_indices.append(flattened_index)
                indices_match &= prev_index == flattened_index
                max_flattened_index = max(max_flattened_index, flattened_index)
        num_common_tokens = len(flattened_indices)
        if num_common_tokens < total_num_scheduled_tokens:
            # If not all requests are decodes from the last iteration,
            # we need to copy the input_ids_cpu to the GPU first.
            self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
            if self.enable_prompt_embeds:
                self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
                self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
        if num_common_tokens == 0:
            # No requests in common with the previous iteration,
            # so input_ids_cpu will have all the input ids.
            return
        if indices_match and max_flattened_index == (num_common_tokens - 1):
            # Common-case optimization: the batch is unchanged
            # and no reordering happened.
            # The indices are both the same permutation of 0..N-1 so
            # we can copy directly using a single slice.
            self.input_ids.gpu[:num_common_tokens].copy_(
                self.input_batch.prev_sampled_token_ids[:num_common_tokens, 0],
                non_blocking=True,
            )
            if self.enable_prompt_embeds:
                self.is_token_ids.gpu[:num_common_tokens] = True
            return
        # Upload the index tensors asynchronously
        # so the scatter can be non-blocking.
        input_ids_index_tensor = torch.tensor(
            flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
        ).to(self.device, non_blocking=True)
        prev_common_req_indices_tensor = torch.tensor(
            prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory
        ).to(self.device, non_blocking=True)
        self.input_ids.gpu.scatter_(
            dim=0,
            index=input_ids_index_tensor,
            src=self.input_batch.prev_sampled_token_ids[
                prev_common_req_indices_tensor, 0
            ],
        )

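    # Illustrative example (not part of the upstream source): with async
    # scheduling, suppose requests A and B were decoded last step and their
    # sampled tokens still live on the GPU in `prev_sampled_token_ids`. If this
    # step schedules [A(1 tok), C(5 toks), B(1 tok)], the last-token slots for
    # A and B are flattened indices 0 and 6, so the scatter above writes
    # prev_sampled_token_ids[[prev_idx_A, prev_idx_B], 0] into
    # input_ids.gpu[[0, 6]] instead of round-tripping those tokens through the
    # CPU (the request names and indices here are hypothetical).
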
    def _get_encoder_seq_lens(
        self,
        scheduler_output: "SchedulerOutput",
        kv_cache_spec: KVCacheSpec,
        num_reqs: int,
    ) -> Optional[np.ndarray]:
        if not isinstance(kv_cache_spec, CrossAttentionSpec):
            return None

        # Build encoder_seq_lens array mapping request indices to
        # encoder lengths for inputs scheduled in this batch
        encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
        for req_id in scheduler_output.scheduled_encoder_inputs:
            req_index = self.input_batch.req_id_to_index[req_id]
            encoder_seq_lens[req_index] = self.max_encoder_len

        return encoder_seq_lens

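    # Illustrative example (not part of the upstream source): for a
    # cross-attention KV cache group with num_reqs == 4 where only the requests
    # at batch indices 1 and 3 have encoder inputs scheduled this step, the
    # helper above returns [0, max_encoder_len, 0, max_encoder_len].
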
    def _prepare_inputs(
        self, scheduler_output: "SchedulerOutput"
    ) -> tuple[
        PerLayerAttnMetadata,
        torch.Tensor,
        Optional[SpecDecodeMetadata],
        np.ndarray,
        Optional[CommonAttentionMetadata],
        int,
        Optional[UBatchSlices],
        Optional[torch.Tensor],
        bool,
    ]:
        """
        :return: tuple[
            attn_metadata: layer-to-attention_metadata mapping,
            logits_indices, spec_decode_metadata,
            num_scheduled_tokens, spec_decode_common_attn_metadata,
            max_num_scheduled_tokens, use_cascade_attn
        ]
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit_block_table(num_reqs)

        # Get the number of scheduled tokens for each request.
        req_ids = self.input_batch.req_ids
        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
        max_num_scheduled_tokens = max(tokens)

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)

        # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
        # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)

        # Get positions.
        positions_np = self.positions.np[:total_num_scheduled_tokens]
        np.add(
            self.input_batch.num_computed_tokens_cpu[req_indices],
            arange,
            out=positions_np,
        )

        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (
            positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]
        )
        token_indices_tensor = torch.from_numpy(token_indices)

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(
            self.input_batch.token_ids_cpu_tensor.flatten(),
            0,
            token_indices_tensor,
            out=self.input_ids.cpu[:total_num_scheduled_tokens],
        )
        if self.enable_prompt_embeds:
            is_token_ids = self.input_batch.is_token_ids.flatten()
            torch.index_select(
                is_token_ids,
                0,
                token_indices_tensor,
                out=self.is_token_ids.cpu[:total_num_scheduled_tokens],
            )

        # Because we did not pre-allocate a massive prompt_embeds CPU tensor on
        # the InputBatch, we need to fill in the prompt embeds into the expected
        # spots in the GpuModelRunner's pre-allocated prompt_embeds tensor.
        if self.input_batch.req_prompt_embeds:
            output_idx = 0
            for req_idx in range(num_reqs):
                num_sched = num_scheduled_tokens[req_idx]

                # Skip if this request doesn't have embeddings
                if req_idx not in self.input_batch.req_prompt_embeds:
                    output_idx += num_sched
                    continue

                # Skip if no tokens scheduled
                if num_sched <= 0:
                    output_idx += num_sched
                    continue

                req_embeds = self.input_batch.req_prompt_embeds[req_idx]
                start_pos = self.input_batch.num_computed_tokens_cpu[req_idx]

                # Skip if trying to read beyond available embeddings
                if start_pos >= req_embeds.shape[0]:
                    output_idx += num_sched
                    continue

                # Copy available embeddings
                end_pos = start_pos + num_sched
                actual_end = min(end_pos, req_embeds.shape[0])
                actual_num_sched = actual_end - start_pos

                if actual_num_sched > 0:
                    self.inputs_embeds.cpu[
                        output_idx : output_idx + actual_num_sched
                    ].copy_(req_embeds[start_pos:actual_end])

                output_idx += num_sched

        self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
        self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)

        # Prepare the attention metadata.
        self.query_start_loc.np[0] = 0
        self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
        # Note: pad query_start_loc to be non-decreasing, as kernels
        # like FlashAttention require that.
        self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1])
        self.query_start_loc.copy_to_gpu()
        query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]

        num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
        num_tokens_padded = num_tokens_unpadded + self.get_local_padding(
            num_tokens_unpadded
        )
        uniform_decode = (
            max_num_scheduled_tokens == self.uniform_decode_query_len
        ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
        ubatch_slices, num_tokens_after_padding = ubatch_split(
            num_scheduled_tokens,
            num_tokens_unpadded,
            num_tokens_padded,
            uniform_decode=uniform_decode,
            vllm_config=self.vllm_config,
        )

        self.seq_lens.np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
        )
        # Fill unused with 0 for full cuda graph mode.
        self.seq_lens.np[num_reqs:].fill(0)
        self.seq_lens.copy_to_gpu()
        seq_lens = self.seq_lens.gpu[:num_reqs]
        max_seq_len = self.seq_lens.np[:num_reqs].max().item()

        num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids]
        num_tokens_np = np.array(num_tokens, dtype=np.int32)

        # Record the index of requests that should not be sampled,
        # so that we can clear the sampled tokens before returning.
        discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np
        discard_request_indices = np.nonzero(discard_requests_mask)[0]
        self.num_discarded_requests = len(discard_request_indices)
        self.discard_request_indices.np[: self.num_discarded_requests] = (
            discard_request_indices
        )

        self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)

        # Copy the tensors to the GPU.
        self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)

        if self.uses_mrope:
            # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
            self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
                non_blocking=True,
            )
        else:
            # Common case (1D positions)
            self.positions.copy_to_gpu(total_num_scheduled_tokens)

        use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
        if not use_spec_decode:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = query_start_loc[1:] - 1
            num_draft_tokens = None
            spec_decode_metadata = None
        else:
            # Get the number of draft tokens for each request.
            # Iterate over the dictionary rather than all requests since not all
            # requests have draft tokens.
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            # For chunked prefills, use -1 as mask rather than 0, as guided
            # decoding may rollback speculative tokens.
            num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)
            for (
                req_id,
                draft_token_ids,
            ) in scheduler_output.scheduled_spec_decode_tokens.items():
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)
                num_decode_draft_tokens[req_idx] = (
                    len(draft_token_ids)
                    if (
                        self.input_batch.num_computed_tokens_cpu[req_idx]
                        >= self.input_batch.num_prompt_tokens[req_idx]
                    )
                    else -1
                )
            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens, cu_num_tokens
            )
            logits_indices = spec_decode_metadata.logits_indices

            # For DECODE-only cuda graphs of some attention backends (e.g., GDN).
            self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens
            self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
            self.num_decode_draft_tokens.copy_to_gpu()

        logits_indices_padded = None
        if self.cache_config.kv_sharing_fast_prefill:
            logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
                logits_indices
            )

        attn_metadata: PerLayerAttnMetadata = {}
        if ubatch_slices is not None:
            attn_metadata = [dict() for _ in range(len(ubatch_slices))]
        use_cascade_attn = False

        # Used in the below loop.
        query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1]
        seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
        num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
            :num_reqs
        ]
        spec_decode_common_attn_metadata = None
        if use_spec_decode:
            self.num_accepted_tokens.np[:num_reqs] = (
                self.input_batch.num_accepted_tokens_cpu[:num_reqs]
            )
            self.num_accepted_tokens.np[num_reqs:].fill(1)
            self.num_accepted_tokens.copy_to_gpu()

        # Prepare the attention metadata for each KV cache group and make layers
        # in the same group share the same metadata.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
            self.kv_cache_config.kv_cache_groups
        ):
            encoder_seq_lens = self._get_encoder_seq_lens(
                scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs
            )

            if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec):
                # Encoder-only layers do not have KV cache, so we need to
                # create a dummy block table and slot mapping for them.
                blk_table_tensor = torch.zeros(
                    (num_reqs, 1),
                    dtype=torch.int32,
                    device=self.device,
                )
                slot_mapping = torch.zeros(
                    (total_num_scheduled_tokens,),
                    dtype=torch.int64,
                    device=self.device,
                )
                num_common_prefix_blocks = 0
            else:
                blk_table = self.input_batch.block_table[kv_cache_group_id]
                blk_table_tensor = blk_table.get_device_tensor(num_reqs)
                slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens]

                # Fill unused with -1. Needed for reshape_and_cache in full cuda
                # graph mode.
                blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1)
                num_common_prefix_blocks = scheduler_output.num_common_prefix_blocks[
                    kv_cache_group_id
                ]

            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=query_start_loc,
                query_start_loc_cpu=query_start_loc_cpu,
                seq_lens=seq_lens,
                seq_lens_cpu=seq_lens_cpu,
                num_computed_tokens_cpu=num_computed_tokens_cpu,
                num_reqs=num_reqs,
                num_actual_tokens=total_num_scheduled_tokens,
                max_query_len=max_num_scheduled_tokens,
                max_seq_len=max_seq_len,
                block_table_tensor=blk_table_tensor,
                slot_mapping=slot_mapping,
                logits_indices_padded=logits_indices_padded,
                num_logits_indices=logits_indices.size(0),
                causal=True,
                encoder_seq_lens=encoder_seq_lens,
            )

            if self.speculative_config and spec_decode_common_attn_metadata is None:
                if isinstance(self.drafter, EagleProposer):
                    if (
                        self.drafter.attn_layer_names[0]
                        in kv_cache_group_spec.layer_names
                    ):
                        spec_decode_common_attn_metadata = common_attn_metadata
                else:
                    spec_decode_common_attn_metadata = common_attn_metadata

            for attn_group in self.attn_groups[kv_cache_group_id]:
                # Prepare for cascade attention if enabled & beneficial.
                common_prefix_len = 0
                builder = attn_group.get_metadata_builder()
                if self.cascade_attn_enabled:
                    common_prefix_len = self._compute_cascade_attn_prefix_len(
                        num_scheduled_tokens,
                        num_common_prefix_blocks,
                        attn_group.kv_cache_spec,
                        builder,
                    )

                extra_attn_metadata_args = {}
                if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
                    extra_attn_metadata_args = dict(
                        num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs],
                        num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[
                            :num_reqs
                        ],
                    )

                if ubatch_slices is not None:
                    common_attn_metadata_list = split_attn_metadata(
                        ubatch_slices, common_attn_metadata
                    )
                    for ubid, common_attn_metadata in enumerate(
                        common_attn_metadata_list
                    ):
                        attn_metadata_i = attn_group.get_metadata_builder(
                            ubatch_id=ubid
                        ).build(
                            common_prefix_len=common_prefix_len,
                            common_attn_metadata=common_attn_metadata,
                        )
                        for layer_name in kv_cache_group_spec.layer_names:
                            assert type(attn_metadata) is list
                            attn_metadata[ubid][layer_name] = attn_metadata_i
                else:
                    assert isinstance(attn_metadata, dict)
                    attn_metadata_i = builder.build(
                        common_prefix_len=common_prefix_len,
                        common_attn_metadata=common_attn_metadata,
                        **extra_attn_metadata_args,
                    )
                    use_cascade_attn |= getattr(attn_metadata_i, "use_cascade", False)
                    for layer_name in attn_group.layer_names:
                        attn_metadata[layer_name] = attn_metadata_i

        # Disable cascade attention when DBO (dual-batch overlap) is used.
        if ubatch_slices is not None:
            use_cascade_attn = False

        # Hot-swap LoRA model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

        return (
            attn_metadata,
            logits_indices,
            spec_decode_metadata,
            num_scheduled_tokens,
            spec_decode_common_attn_metadata,
            max_num_scheduled_tokens,
            ubatch_slices,
            num_tokens_after_padding,
            use_cascade_attn,
        )

def _compute_cascade_attn_prefix_len(
|
|
self,
|
|
num_scheduled_tokens: np.ndarray,
|
|
num_common_prefix_blocks: int,
|
|
kv_cache_spec: KVCacheSpec,
|
|
attn_metadata_builder: AttentionMetadataBuilder,
|
|
) -> int:
|
|
"""Compute the length of the common prefix for cascade attention.
|
|
|
|
NOTE(woosuk): The common prefix length returned by this function
|
|
represents the length used specifically for cascade attention, not the
|
|
actual number of tokens shared between requests. When cascade attention
|
|
is disabled (use_cascade=False), this function returns 0 even if
|
|
requests share common tokens. Additionally, the common prefix length is
|
|
truncated to a multiple of the block size and may be further truncated
|
|
due to implementation details explained below.
|
|
|
|
Args:
|
|
num_scheduled_tokens: Number of tokens scheduled per request.
|
|
num_common_prefix_blocks: Number of shared KV cache blocks.
|
|
|
|
Returns:
|
|
int: Length of common prefix in tokens.
|
|
"""
|
|
common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size
|
|
if common_prefix_len == 0:
|
|
# Common case.
|
|
return 0
|
|
|
|
# NOTE(woosuk): Cascade attention uses two attention kernels: one
|
|
# for the common prefix and the other for the rest. For the first
|
|
# kernel, we concatenate all the query tokens (possibly from
|
|
# different requests) and treat them as if they are from the same
|
|
# request. Then, we use bi-directional attention to process the
|
|
# common prefix in the KV cache. Importantly, this means that the
|
|
# first kernel does not do any masking.
|
|
|
|
# Consider the following example:
|
|
# Request 1's input query: [D, E, X]
|
|
# Request 1's kv cache: [A, B, C, D, E, X]
|
|
# Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
|
|
# Request 2's input query: [E, Y]
|
|
# Request 2's kv cache: [A, B, C, D, E, Y]
|
|
# Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])
|
|
|
|
# If we use [A, B, C, D, E] as the common prefix, then the
|
|
# first kernel will compute the bi-directional attention between
|
|
# input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
|
|
# However, this is wrong because D in Request 1 should not attend to
|
|
# E in the common prefix (i.e., we need masking).
|
|
# To avoid this, [A, B, C, D] should be the common prefix.
|
|
# That is, the common prefix should be capped by the minimum
|
|
# num_computed_tokens among the requests, and plus one to include
|
|
# the first token of the query.
|
|
|
|
# In practice, we use [A, B, C] as the common prefix, instead of
|
|
# [A, B, C, D] (i.e., the common prefix is capped by the minimum
|
|
# num_computed_tokens, without plus one).
|
|
# This is because of an implementation detail: We want to always
|
|
# use two kernels for cascade attention. Let's imagine:
|
|
# Request 3's input query: [D]
|
|
# Request 3's kv cache: [A, B, C, D]
|
|
# Request 3's num_computed_tokens: 3 (i.e., [A, B, C])
|
|
# If we use [A, B, C, D] as the common prefix for Request 1-3,
|
|
# then Request 3 will be processed only by the first kernel,
|
|
# and the second kernel will get an empty input. While this is not
|
|
# a fundamental problem, our current implementation does not support
|
|
# this case.
|
|
num_reqs = len(num_scheduled_tokens)
|
|
common_prefix_len = min(
|
|
common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min()
|
|
)
|
|
# common_prefix_len should be a multiple of the block size.
|
|
common_prefix_len = (
|
|
common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size
|
|
)
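# Illustrative sketch (assumed values, not from the original source):
# with block_size=16 and num_common_prefix_blocks=4, common_prefix_len
# starts at 64; if min(num_computed_tokens) is 50, it is first capped
# to 50 and then floored to 48, the nearest multiple of the block size.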
|
|
use_sliding_window = isinstance(kv_cache_spec, SlidingWindowSpec) or (
|
|
isinstance(kv_cache_spec, FullAttentionSpec)
|
|
and kv_cache_spec.sliding_window is not None
|
|
)
|
|
use_local_attention = isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) or (
|
|
isinstance(kv_cache_spec, FullAttentionSpec)
|
|
and kv_cache_spec.attention_chunk_size is not None
|
|
)
|
|
assert isinstance(kv_cache_spec, AttentionSpec)
|
|
use_cascade = attn_metadata_builder.use_cascade_attention(
|
|
common_prefix_len=common_prefix_len,
|
|
query_lens=num_scheduled_tokens,
|
|
num_query_heads=self.num_query_heads,
|
|
num_kv_heads=kv_cache_spec.num_kv_heads,
|
|
use_alibi=self.use_alibi,
|
|
use_sliding_window=use_sliding_window,
|
|
use_local_attention=use_local_attention,
|
|
num_sms=self.num_sms,
|
|
)
|
|
return common_prefix_len if use_cascade else 0
|
|
|
|
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
|
|
mrope_pos_ptr = 0
|
|
for index, req_id in enumerate(self.input_batch.req_ids):
|
|
req = self.requests[req_id]
|
|
assert req.mrope_positions is not None
|
|
|
|
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index]
|
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
|
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
|
req.prompt_token_ids, req.prompt_embeds
|
|
)
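# Illustrative sketch (assumed values, not from the original source):
# with num_prompt_tokens=10, num_computed_tokens=8 and
# num_scheduled_tokens=5, the first 2 scheduled tokens still belong to
# the prompt (prompt_part_len=2) and the remaining 3 are generated
# tokens (completion_part_len=3).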
|
|
|
|
if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
|
|
prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)
|
|
completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)
|
|
else:
|
|
prompt_part_len = num_scheduled_tokens
|
|
completion_part_len = 0
|
|
|
|
assert num_scheduled_tokens == prompt_part_len + completion_part_len
|
|
|
|
if prompt_part_len > 0:
|
|
# prompt's mrope_positions are pre-computed
|
|
dst_start = mrope_pos_ptr
|
|
dst_end = mrope_pos_ptr + prompt_part_len
|
|
src_start = num_computed_tokens
|
|
src_end = num_computed_tokens + prompt_part_len
|
|
|
|
self.mrope_positions.cpu[:, dst_start:dst_end] = req.mrope_positions[
|
|
:, src_start:src_end
|
|
]
|
|
mrope_pos_ptr += prompt_part_len
|
|
|
|
if completion_part_len > 0:
|
|
# compute completion's mrope_positions on-the-fly
|
|
dst_start = mrope_pos_ptr
|
|
dst_end = mrope_pos_ptr + completion_part_len
|
|
|
|
MRotaryEmbedding.get_next_input_positions_tensor(
|
|
out=self.mrope_positions.np,
|
|
out_offset=dst_start,
|
|
mrope_position_delta=req.mrope_position_delta,
|
|
context_len=num_computed_tokens + prompt_part_len,
|
|
num_new_tokens=completion_part_len,
|
|
)
|
|
|
|
mrope_pos_ptr += completion_part_len
|
|
|
|
def _calc_spec_decode_metadata(
|
|
self,
|
|
num_draft_tokens: np.ndarray,
|
|
cu_num_scheduled_tokens: np.ndarray,
|
|
) -> SpecDecodeMetadata:
|
|
# Inputs:
|
|
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
|
|
# num_draft_tokens: [ 3, 0, 2, 0, 1]
|
|
# Outputs:
|
|
# cu_num_draft_tokens: [ 3, 3, 5, 5, 6]
|
|
# logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106,
|
|
# 206, 207, 208]
|
|
# target_logits_indices: [ 0, 1, 2, 5, 6, 9]
|
|
# bonus_logits_indices: [ 3, 4, 7, 8, 10]
|
|
|
|
# Compute the logits indices.
|
|
# [4, 1, 3, 1, 2]
|
|
num_sampled_tokens = num_draft_tokens + 1
|
|
|
|
# Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11]
|
|
# arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
|
|
cu_num_sampled_tokens, arange = self._get_cumsum_and_arange(
|
|
num_sampled_tokens, cumsum_dtype=np.int32
|
|
)
|
|
# Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
|
|
logits_indices = np.repeat(
|
|
cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens
|
|
)
|
|
# Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
|
|
logits_indices += arange
|
|
|
|
# Compute the bonus logits indices.
|
|
bonus_logits_indices = cu_num_sampled_tokens - 1
|
|
|
|
# Compute the draft logits indices.
|
|
# cu_num_draft_tokens: [3, 3, 5, 5, 6]
|
|
# arange: [0, 1, 2, 0, 1, 0]
|
|
cu_num_draft_tokens, arange = self._get_cumsum_and_arange(
|
|
num_draft_tokens, cumsum_dtype=np.int32
|
|
)
|
|
# [0, 0, 0, 5, 5, 9]
|
|
target_logits_indices = np.repeat(
|
|
cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens
|
|
)
|
|
# [0, 1, 2, 5, 6, 9]
|
|
target_logits_indices += arange
|
|
|
|
# TODO: Optimize the CPU -> GPU copy.
|
|
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
|
|
self.device, non_blocking=True
|
|
)
|
|
logits_indices = torch.from_numpy(logits_indices).to(
|
|
self.device, non_blocking=True
|
|
)
|
|
target_logits_indices = torch.from_numpy(target_logits_indices).to(
|
|
self.device, non_blocking=True
|
|
)
|
|
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
|
|
self.device, non_blocking=True
|
|
)
|
|
|
|
# Compute the draft token ids.
|
|
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
|
|
draft_token_ids = self.input_ids.gpu[logits_indices]
|
|
draft_token_ids = draft_token_ids[target_logits_indices + 1]
|
|
|
|
metadata = SpecDecodeMetadata(
|
|
draft_token_ids=draft_token_ids,
|
|
num_draft_tokens=num_draft_tokens.tolist(),
|
|
cu_num_draft_tokens=cu_num_draft_tokens,
|
|
target_logits_indices=target_logits_indices,
|
|
bonus_logits_indices=bonus_logits_indices,
|
|
logits_indices=logits_indices,
|
|
)
|
|
return metadata
|
|
|
|
def _prepare_kv_sharing_fast_prefill(
|
|
self,
|
|
logits_indices: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
assert self.kv_sharing_fast_prefill_logits_indices is not None
|
|
num_logits = logits_indices.shape[0]
|
|
assert num_logits > 0
|
|
self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(logits_indices)
|
|
# There may be leftover indices in logits_indices[num_logits:]
|
|
# from previous iterations, whose values may be greater than the
|
|
# batch size in the current iteration. To ensure indices are always
|
|
# valid, we fill the padded indices with the last index.
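# Illustrative sketch (assumed values): if logits_indices is [3, 7, 12]
# and the persistent buffer previously held larger indices, the buffer
# becomes [3, 7, 12, 12, 12, ...] so every padded slot points at a
# valid row (the last real index).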
|
|
self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
|
|
logits_indices[-1].item()
|
|
)
|
|
if (
|
|
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
|
and num_logits <= self.cudagraph_batch_sizes[-1]
|
|
):
|
|
# Use piecewise CUDA graphs.
|
|
# Add padding to the batch size.
|
|
num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits)
|
|
else:
|
|
num_logits_padded = num_logits
|
|
logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[
|
|
:num_logits_padded
|
|
]
|
|
return logits_indices_padded
|
|
|
|
def _batch_mm_kwargs_from_scheduler(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]:
|
|
"""Batch multimodal kwargs from scheduled encoder inputs.
|
|
|
|
Args:
|
|
scheduler_output: The scheduler output containing scheduled encoder
|
|
inputs.
|
|
|
|
Returns:
|
|
A tuple of (mm_kwargs, mm_hashes_pos) where:
|
|
- mm_kwargs: List of multimodal kwargs items to be batched
|
|
- mm_hashes_pos: List of (mm_hash, position_info) tuples
|
|
"""
|
|
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
|
if not scheduled_encoder_inputs:
|
|
return [], []
|
|
# Batch the multi-modal inputs.
|
|
mm_kwargs = list[MultiModalKwargsItem]()
|
|
# list of tuple (mm_hash, position_info)
|
|
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
|
|
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
|
req_state = self.requests[req_id]
|
|
|
|
for mm_input_id in encoder_input_ids:
|
|
mm_feature = req_state.mm_features[mm_input_id]
|
|
mm_hash = mm_feature.identifier
|
|
mm_kwargs.append(mm_feature.data)
|
|
mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
|
|
|
|
return mm_kwargs, mm_hashes_pos
|
|
|
|
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
|
|
# Batch the multi-modal inputs using the helper method.
|
|
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
|
|
scheduler_output
|
|
)
|
|
|
|
if not mm_kwargs:
|
|
return
|
|
|
|
# Batch mm inputs as much as we can: if a request in the batch has
|
|
# multiple modalities or a different modality than the previous one,
|
|
# we process it separately to preserve item order.
|
|
# FIXME(ywang96): This is a hacky way to deal with multiple modalities
|
|
# in the same batch while still being able to benefit from batching
|
|
# multimodal inputs. The proper solution should be reordering the
|
|
# encoder outputs.
|
|
model = cast(SupportsMultiModal, self.model)
|
|
encoder_outputs = []
|
|
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
|
|
mm_kwargs,
|
|
device=self.device,
|
|
pin_memory=self.pin_memory,
|
|
merge_by_field_config=model.merge_by_field_config,
|
|
):
|
|
# (ekhvedchenia): Temporary hack to limit peak memory usage when
|
|
# processing multimodal data. This solves the issue of the scheduler
# putting too many video samples into a single batch. The scheduler
# compares the pruned vision token count against the compute budget,
# which is incorrect (either the input media size or the non-pruned
# output vision token count should be considered).
|
|
curr_group_outputs = []
|
|
|
|
if self.is_multimodal_pruning_enabled and modality == "video":
|
|
micro_batch_size = 1
|
|
for i in range(0, num_items, micro_batch_size):
|
|
micro_batch_mm_inputs = dict(
|
|
(k, v[i : i + micro_batch_size])
|
|
for k, v in mm_kwargs_group.items()
|
|
)
|
|
|
|
micro_batch_outputs = model.get_multimodal_embeddings(
|
|
**micro_batch_mm_inputs
|
|
)
|
|
|
|
curr_group_outputs.extend(micro_batch_outputs)
|
|
else:
|
|
# Run the encoder.
|
|
# `curr_group_outputs` is either of the following:
|
|
# 1. A tensor of shape (num_items, feature_size, hidden_size)
|
|
# in case feature_size is fixed across all multimodal items.
|
|
# 2. A list or tuple (length: num_items) of tensors,
|
|
# each of shape (feature_size, hidden_size) in case the feature
|
|
# size is dynamic depending on the input multimodal items.
|
|
curr_group_outputs = model.get_multimodal_embeddings(**mm_kwargs_group)
|
|
|
|
sanity_check_mm_encoder_outputs(
|
|
curr_group_outputs,
|
|
expected_num_items=num_items,
|
|
)
|
|
encoder_outputs.extend(curr_group_outputs)
|
|
|
|
# Cache the encoder outputs by mm_hash
|
|
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
|
|
self.encoder_cache[mm_hash] = scatter_mm_placeholders(
|
|
output,
|
|
is_embed=pos_info.is_embed,
|
|
)
|
|
|
|
def _gather_mm_embeddings(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
shift_computed_tokens: int = 0,
|
|
) -> tuple[list[torch.Tensor], torch.Tensor]:
|
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
|
|
|
mm_embeds = list[torch.Tensor]()
|
|
is_mm_embed = self.is_mm_embed.cpu
|
|
is_mm_embed[:total_num_scheduled_tokens] = False
|
|
|
|
req_start_idx = 0
|
|
should_sync_mrope_positions = False
|
|
|
|
for req_id in self.input_batch.req_ids:
|
|
mm_embeds_req: list[torch.Tensor] = []
|
|
|
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
|
req_state = self.requests[req_id]
|
|
num_computed_tokens = req_state.num_computed_tokens + shift_computed_tokens
|
|
|
|
for mm_feature in req_state.mm_features:
|
|
pos_info = mm_feature.mm_position
|
|
start_pos = pos_info.offset
|
|
num_encoder_tokens = pos_info.length
|
|
|
|
# The encoder output is needed if the two ranges overlap:
|
|
# [num_computed_tokens,
|
|
# num_computed_tokens + num_scheduled_tokens) and
|
|
# [start_pos, start_pos + num_encoder_tokens)
|
|
if start_pos >= num_computed_tokens + num_scheduled_tokens:
|
|
# The encoder output is not needed in this step.
|
|
break
|
|
if start_pos + num_encoder_tokens <= num_computed_tokens:
|
|
# The encoder output is already processed and stored
|
|
# in the decoder's KV cache.
|
|
continue
|
|
|
|
start_idx = max(num_computed_tokens - start_pos, 0)
|
|
end_idx = min(
|
|
num_computed_tokens - start_pos + num_scheduled_tokens,
|
|
num_encoder_tokens,
|
|
)
|
|
assert start_idx < end_idx
|
|
|
|
mm_hash = mm_feature.identifier
|
|
encoder_output = self.encoder_cache.get(mm_hash, None)
|
|
assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
|
|
|
|
if (is_embed := pos_info.is_embed) is not None:
|
|
is_embed = is_embed[start_idx:end_idx]
|
|
|
|
req_start_pos = req_start_idx + start_pos - num_computed_tokens
|
|
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
|
|
True if is_embed is None else is_embed
|
|
)
|
|
|
|
mm_embeds_item = gather_mm_placeholders(
|
|
encoder_output[start_idx:end_idx],
|
|
is_embed=is_embed,
|
|
)
|
|
mm_embeds_req.append(mm_embeds_item)
|
|
|
|
if self.is_multimodal_pruning_enabled and self.uses_mrope:
|
|
assert req_state.mrope_positions is not None
|
|
should_sync_mrope_positions = True
|
|
mm_embeds_req, new_mrope_positions, new_delta = (
|
|
self.model.recompute_mrope_positions(
|
|
input_ids=req_state.prompt_token_ids,
|
|
multimodal_embeddings=mm_embeds_req,
|
|
mrope_positions=req_state.mrope_positions,
|
|
num_computed_tokens=req_state.num_computed_tokens,
|
|
)
|
|
)
|
|
req_state.mrope_positions.copy_(new_mrope_positions)
|
|
req_state.mrope_position_delta = new_delta
|
|
|
|
mm_embeds.extend(mm_embeds_req)
|
|
req_start_idx += num_scheduled_tokens
|
|
|
|
is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)
|
|
|
|
if should_sync_mrope_positions:
|
|
self._calc_mrope_positions(scheduler_output)
|
|
self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)
|
|
|
|
return mm_embeds, is_mm_embed
|
|
|
|
def _extract_encoder_inputs(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
) -> dict[str, torch.Tensor]:
|
|
"""Extract encoder inputs for encoder-decoder models.
|
|
|
|
This method extracts multimodal input features from scheduled encoder
|
|
inputs and formats them for the encoder-decoder model forward pass.
|
|
"""
|
|
# Batch the multi-modal inputs using the helper method.
|
|
mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output)
|
|
|
|
if not mm_kwargs:
|
|
return {}
|
|
|
|
# Group MM kwargs by modality and extract features
|
|
model = cast(SupportsMultiModal, self.model)
|
|
encoder_features = {}
|
|
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
|
|
mm_kwargs,
|
|
device=self.device,
|
|
pin_memory=self.pin_memory,
|
|
merge_by_field_config=model.merge_by_field_config,
|
|
):
|
|
# Add the grouped features to encoder_features dict
|
|
# This allows the model to receive them as kwargs (e.g.,
|
|
# input_features=...)
|
|
encoder_features.update(mm_kwargs_group)
|
|
|
|
return encoder_features
|
|
|
|
def get_model(self) -> nn.Module:
|
|
# get raw model out of the cudagraph wrapper.
|
|
if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)):
|
|
return self.model.unwrap()
|
|
return self.model
|
|
|
|
def get_supported_generation_tasks(self) -> list[GenerationTask]:
|
|
model = self.get_model()
|
|
supported_tasks = list[GenerationTask]()
|
|
|
|
if is_text_generation_model(model):
|
|
supported_tasks.append("generate")
|
|
|
|
if supports_transcription(model):
|
|
if model.supports_transcription_only:
|
|
return ["transcription"]
|
|
|
|
supported_tasks.append("transcription")
|
|
|
|
return supported_tasks
|
|
|
|
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
|
|
model = self.get_model()
|
|
if not is_pooling_model(model):
|
|
return []
|
|
|
|
supported_tasks = list(model.pooler.get_supported_tasks())
|
|
|
|
if (
|
|
self.scheduler_config.chunked_prefill_enabled
|
|
and "encode" in supported_tasks
|
|
):
|
|
supported_tasks.remove("encode")
|
|
|
|
logger.debug_once(
|
|
"Chunked prefill is not supported with "
|
|
"encode task which using ALL pooling. "
|
|
"Please turn off chunked prefill by "
|
|
"`--no-enable-chunked-prefill` before using it."
|
|
)
|
|
|
|
if "score" in supported_tasks:
|
|
num_labels = getattr(self.model_config.hf_config, "num_labels", 0)
|
|
if num_labels != 1:
|
|
supported_tasks.remove("score")
|
|
logger.debug_once("Score API is only enabled for num_labels == 1.")
|
|
|
|
return supported_tasks
|
|
|
|
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
|
tasks = list[SupportedTask]()
|
|
|
|
if self.model_config.runner_type == "generate":
|
|
tasks.extend(self.get_supported_generation_tasks())
|
|
if self.model_config.runner_type == "pooling":
|
|
tasks.extend(self.get_supported_pooling_tasks())
|
|
|
|
return tuple(tasks)
|
|
|
|
def sync_and_slice_intermediate_tensors(
|
|
self,
|
|
num_tokens: int,
|
|
intermediate_tensors: IntermediateTensors,
|
|
sync_self: bool,
|
|
) -> IntermediateTensors:
|
|
assert self.intermediate_tensors is not None
|
|
|
|
tp = self.vllm_config.parallel_config.tensor_parallel_size
|
|
is_rs = is_residual_scattered_for_sp(self.vllm_config, num_tokens)
|
|
|
|
# When sequence parallelism is enabled, the "residual" tensor is sharded
|
|
# across tensor parallel ranks, so each rank only needs its own slice.
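# Illustrative sketch (assumed values): with tensor_parallel_size=4 and
# num_tokens=128, the "residual" entry is sliced to 32 rows per rank
# while every other intermediate tensor keeps all 128 rows.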
|
|
if sync_self:
|
|
assert intermediate_tensors is not None
|
|
for k, v in intermediate_tensors.items():
|
|
is_scattered = k == "residual" and is_rs
|
|
copy_len = num_tokens // tp if is_scattered else num_tokens
|
|
self.intermediate_tensors[k][:copy_len].copy_(
|
|
v[:copy_len], non_blocking=True
|
|
)
|
|
|
|
return IntermediateTensors(
|
|
{
|
|
k: v[: num_tokens // tp]
|
|
if k == "residual" and is_rs
|
|
else v[:num_tokens]
|
|
for k, v in self.intermediate_tensors.items()
|
|
}
|
|
)
|
|
|
|
def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None:
|
|
"""
|
|
Step for the EPLB (Expert Parallelism Load Balancing) state.
|
|
"""
|
|
if not self.parallel_config.enable_eplb:
|
|
return
|
|
|
|
assert self.eplb_state is not None
|
|
model = self.get_model()
|
|
assert is_mixture_of_experts(model)
|
|
self.eplb_state.step(
|
|
model,
|
|
is_dummy,
|
|
is_profile,
|
|
log_stats=self.parallel_config.eplb_config.log_balancedness,
|
|
)
|
|
|
|
def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
|
|
"""
|
|
Determines the total number of tokens that each rank will run.
|
|
All ranks will be padded out so that they run with the same number
|
|
of tokens.
|
|
|
|
Returns: tuple[
|
|
num_pad_tokens: The number of tokens that will be added to the batch
|
|
num_tokens_after_padding: A tensor containing the total number of
|
|
tokens for each DP rank including padding.
|
|
]
|
|
"""
|
|
dp_size = self.vllm_config.parallel_config.data_parallel_size
|
|
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
|
|
|
|
# For DP: Don't pad when setting enforce_eager.
|
|
# This lets us set enforce_eager on the prefiller in a P/D setup and
|
|
# still use CUDA graphs (enabled by this padding) on the decoder.
|
|
#
|
|
# TODO(tms) : There are many cases where padding is enabled for
|
|
# prefills, causing unnecessary and excessive padding of activations.
|
|
|
|
if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
|
|
# Early exit.
|
|
return 0, None
|
|
|
|
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
|
|
num_tokens, dp_size, dp_rank
|
|
)
|
|
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
|
|
num_tokens_after_padding = torch.tensor(
|
|
[max_tokens_across_dp_cpu] * dp_size, device="cpu", dtype=torch.int32
|
|
)
|
|
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
|
|
|
|
def get_local_padding(self, num_tokens_unpadded: int) -> int:
|
|
num_tokens_padded = num_tokens_unpadded
|
|
|
|
if (
|
|
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
|
and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]
|
|
):
|
|
# Use piecewise CUDA graphs.
|
|
# Add padding to the batch size.
|
|
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded)
|
|
else:
|
|
# Eager mode.
|
|
# Pad tokens to a multiple of tensor_parallel_size when
# collective fusion for sequence parallelism is enabled.
|
|
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
|
if (
|
|
self.vllm_config.compilation_config.pass_config.enable_sequence_parallelism
|
|
and tp_size > 1
|
|
):
|
|
num_tokens_padded = round_up(num_tokens_unpadded, tp_size)
|
|
|
|
num_pad_tokens = num_tokens_padded - num_tokens_unpadded
|
|
return num_pad_tokens
|
|
|
|
# This is where the second ubatch is adjusted to account for the padding.
|
|
# Should be called after attention metadata creation. This just pads
|
|
# the second ubatch slice out to the total number of tokens
|
|
# (num_tokens + padding)
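# Illustrative sketch (assumed values): with ubatch token slices
# [0, 64) and [64, 120) and num_total_tokens=128 after padding, the
# second slice is widened to [64, 128) so both ubatches together cover
# the padded token count.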
|
|
def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, num_total_tokens: int):
|
|
padded_second_ubatch_slice = slice(
|
|
ubatch_slices[1].token_slice.start, num_total_tokens
|
|
)
|
|
ubatch_slices[1] = UBatchSlice(
|
|
padded_second_ubatch_slice, padded_second_ubatch_slice
|
|
)
|
|
|
|
def _pool(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
num_scheduled_tokens: int,
|
|
num_scheduled_tokens_np: np.ndarray,
|
|
) -> ModelRunnerOutput:
|
|
assert self.input_batch.num_reqs == len(self.input_batch.pooling_params), (
|
|
"Either all or none of the requests in a batch must be pooling request"
|
|
)
|
|
|
|
hidden_states = hidden_states[:num_scheduled_tokens]
|
|
pooling_metadata = self.input_batch.get_pooling_metadata()
|
|
pooling_metadata.build_pooling_cursor(
|
|
num_scheduled_tokens_np.tolist(), device=hidden_states.device
|
|
)
|
|
seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs]
|
|
|
|
model = cast(VllmModelForPooling, self.model)
|
|
raw_pooler_output: PoolerOutput = model.pooler(
|
|
hidden_states=hidden_states,
|
|
pooling_metadata=pooling_metadata,
|
|
)
|
|
raw_pooler_output = json_map_leaves(
|
|
lambda x: x.to("cpu", non_blocking=True),
|
|
raw_pooler_output,
|
|
)
|
|
self._sync_device()
|
|
|
|
pooler_output: list[Optional[torch.Tensor]] = []
|
|
for raw_output, seq_len, prompt_len in zip(
|
|
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens
|
|
):
|
|
output = raw_output if seq_len == prompt_len else None
|
|
pooler_output.append(output)
|
|
|
|
return ModelRunnerOutput(
|
|
req_ids=self.input_batch.req_ids,
|
|
req_id_to_index=self.input_batch.req_id_to_index,
|
|
sampled_token_ids=[],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=pooler_output,
|
|
)
|
|
|
|
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
|
|
if (
|
|
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
|
and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH
|
|
and hasattr(self, "cudagraph_batch_sizes")
|
|
and self.cudagraph_batch_sizes
|
|
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]
|
|
):
|
|
# Use CUDA graphs.
|
|
# Add padding to the batch size.
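# Illustrative sketch (assumed capture sizes, not from the original
# source): with cudagraph_batch_sizes = [1, 2, 4, 8, 16], a batch of
# 7 scheduled tokens is padded up to 8, the smallest captured size
# that fits it.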
|
|
return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens)
|
|
|
|
# Eager mode.
|
|
# Pad tokens to a multiple of tensor_parallel_size when
# collective fusion for sequence parallelism is enabled.
|
|
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
|
if (
|
|
self.compilation_config.pass_config.enable_sequence_parallelism
|
|
and tp_size > 1
|
|
):
|
|
return round_up(num_scheduled_tokens, tp_size)
|
|
return num_scheduled_tokens
|
|
|
|
def _preprocess(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
ubatch_slices: Optional[UBatchSlices] = None,
|
|
num_tokens_after_padding: Optional[torch.Tensor] = None,
|
|
) -> tuple[
|
|
int,
|
|
int,
|
|
Optional[torch.Tensor],
|
|
Optional[torch.Tensor],
|
|
Optional[torch.Tensor],
|
|
torch.Tensor,
|
|
Optional[IntermediateTensors],
|
|
dict[str, Any],
|
|
]:
|
|
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
|
if ubatch_slices:
|
|
assert num_tokens_after_padding is not None
|
|
num_input_tokens = int(num_tokens_after_padding[0].item() * 2)
|
|
self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
|
|
elif ubatch_slices is None:
|
|
num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens)
|
|
num_pad, num_tokens_after_padding = self.get_dp_padding(num_input_tokens)
|
|
num_input_tokens += num_pad
|
|
|
|
# _prepare_inputs may reorder the batch, so we must gather multi
|
|
# modal outputs after that to ensure the correct order
|
|
if (
|
|
self.supports_mm_inputs
|
|
and get_pp_group().is_first_rank
|
|
and not self.model_config.is_encoder_decoder
|
|
):
|
|
# Run the multimodal encoder if any.
|
|
self._execute_mm_encoder(scheduler_output)
|
|
mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output)
|
|
|
|
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
|
# embeddings), we always use embeddings (rather than token ids)
|
|
# as input to the multimodal model, even when the input is text.
|
|
inputs_embeds_scheduled = self.model.get_input_embeddings(
|
|
self.input_ids.gpu[:num_scheduled_tokens],
|
|
multimodal_embeddings=mm_embeds,
|
|
is_multimodal=is_mm_embed,
|
|
)
|
|
|
|
# TODO(woosuk): Avoid the copy. Optimize.
|
|
self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled)
|
|
|
|
input_ids = None
|
|
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
|
|
model_kwargs = {
|
|
**self._init_model_kwargs(num_scheduled_tokens),
|
|
**self._extract_mm_kwargs(scheduler_output),
|
|
}
|
|
elif self.enable_prompt_embeds and get_pp_group().is_first_rank:
|
|
# Get the input embeddings for the tokens that are not input embeds,
|
|
# then put them into the appropriate positions.
|
|
# TODO(qthequartermasterman): Since even when prompt embeds are
|
|
# enabled, (a) not all requests will use prompt embeds, and (b)
|
|
# after the initial prompt is processed, the rest of the generated
|
|
# tokens will be token ids, it is not desirable to have the
|
|
# embedding layer outside of the CUDA graph all the time. The v0
|
|
# engine avoids this by "double compiling" the CUDA graph, once
|
|
# with input_ids and again with inputs_embeds, for all num_tokens.
|
|
# If a batch only has token ids, then including the embedding layer
|
|
# in the CUDA graph will be more performant (like in the else case
|
|
# below).
|
|
token_ids_idx = (
|
|
self.is_token_ids.gpu[:num_scheduled_tokens]
|
|
.nonzero(as_tuple=False)
|
|
.squeeze(1)
|
|
)
|
|
# Some token ids may need to become embeds
|
|
if token_ids_idx.numel() > 0:
|
|
token_ids = self.input_ids.gpu[token_ids_idx]
|
|
tokens_to_embeds = self.model.get_input_embeddings(input_ids=token_ids)
|
|
self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds
|
|
|
|
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
|
|
model_kwargs = self._init_model_kwargs(num_input_tokens)
|
|
input_ids = None
|
|
else:
|
|
# For text-only models, we use token ids as input.
|
|
# While it is possible to use embeddings as input just like the
|
|
# multimodal models, it is not desirable for performance since
|
|
# then the embedding layer is not included in the CUDA graph.
|
|
input_ids = self.input_ids.gpu[:num_input_tokens]
|
|
inputs_embeds = None
|
|
model_kwargs = self._init_model_kwargs(num_input_tokens)
|
|
if self.uses_mrope:
|
|
positions = self.mrope_positions.gpu[:, :num_input_tokens]
|
|
else:
|
|
positions = self.positions.gpu[:num_input_tokens]
|
|
|
|
if get_pp_group().is_first_rank:
|
|
intermediate_tensors = None
|
|
else:
|
|
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
|
num_input_tokens, intermediate_tensors, True
|
|
)
|
|
|
|
if (
|
|
self.model_config.is_encoder_decoder
|
|
and scheduler_output.scheduled_encoder_inputs
|
|
):
|
|
encoder_inputs = self._extract_encoder_inputs(scheduler_output)
|
|
model_kwargs.update(encoder_inputs)
|
|
|
|
return (
|
|
num_scheduled_tokens,
|
|
num_input_tokens,
|
|
num_tokens_after_padding,
|
|
input_ids,
|
|
inputs_embeds,
|
|
positions,
|
|
intermediate_tensors,
|
|
model_kwargs,
|
|
)
|
|
|
|
def _sample(
|
|
self,
|
|
logits: Optional[torch.Tensor],
|
|
spec_decode_metadata: Optional[SpecDecodeMetadata],
|
|
) -> SamplerOutput:
|
|
# Sample the next token and get logprobs if needed.
|
|
sampling_metadata = self.input_batch.sampling_metadata
|
|
if spec_decode_metadata is None:
|
|
sampler_output = self.sampler(
|
|
logits=logits,
|
|
sampling_metadata=sampling_metadata,
|
|
)
|
|
else:
|
|
# When indexing with a tensor (bonus_logits_indices), PyTorch
|
|
# creates a new tensor with separate storage from the original
|
|
# logits tensor. This means any in-place operations on bonus_logits
|
|
# won't affect the original logits tensor.
|
|
assert logits is not None
|
|
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
|
|
sampler_output = self.sampler(
|
|
logits=bonus_logits,
|
|
sampling_metadata=sampling_metadata,
|
|
)
|
|
bonus_token_ids = sampler_output.sampled_token_ids
|
|
|
|
# Just like `bonus_logits`, `target_logits` is a new tensor with
|
|
# separate storage from the original `logits` tensor. Therefore,
|
|
# it is safe to update `target_logits` in place.
|
|
target_logits = logits[spec_decode_metadata.target_logits_indices]
|
|
output_token_ids = self.rejection_sampler(
|
|
spec_decode_metadata,
|
|
None, # draft_probs
|
|
target_logits,
|
|
bonus_token_ids,
|
|
sampling_metadata,
|
|
)
|
|
sampler_output.sampled_token_ids = output_token_ids
|
|
self._update_states_after_model_execute(output_token_ids)
|
|
|
|
return sampler_output
|
|
|
|
def _bookkeeping_sync(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
sampler_output: SamplerOutput,
|
|
logits: Optional[torch.Tensor],
|
|
hidden_states: torch.Tensor,
|
|
num_scheduled_tokens: int,
|
|
) -> tuple[
|
|
dict[str, int],
|
|
Optional[LogprobsLists],
|
|
list[list[int]],
|
|
dict[str, Optional[LogprobsTensors]],
|
|
list[str],
|
|
dict[str, int],
|
|
list[int],
|
|
]:
|
|
num_nans_in_logits = {}
|
|
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
|
|
num_nans_in_logits = self._get_nans_in_logits(logits)
|
|
|
|
discard_sampled_tokens_req_indices = self.discard_request_indices.np[
|
|
: self.num_discarded_requests
|
|
]
|
|
for i in discard_sampled_tokens_req_indices:
|
|
gen = self.input_batch.generators.get(int(i))
|
|
if gen is not None:
|
|
gen.set_offset(gen.get_offset() - 4)
|
|
|
|
# Copy some objects so they don't get modified after returning.
|
|
# This is important when using async scheduling.
|
|
req_ids_output_copy = self.input_batch.req_ids.copy()
|
|
req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy()
|
|
|
|
# NOTE: GPU -> CPU Sync happens here.
|
|
# Move as many CPU operations as possible before this sync point.
|
|
logprobs_tensors = sampler_output.logprobs_tensors
|
|
logprobs_lists = (
|
|
logprobs_tensors.tolists() if logprobs_tensors is not None else None
|
|
)
|
|
|
|
# Compute prompt logprobs if needed.
|
|
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
|
|
hidden_states[:num_scheduled_tokens],
|
|
scheduler_output.num_scheduled_tokens,
|
|
)
|
|
|
|
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
|
|
sampled_token_ids = sampler_output.sampled_token_ids
|
|
invalid_req_indices = []
|
|
if not self.use_async_scheduling:
|
|
# Get the valid generated tokens.
|
|
max_gen_len = sampled_token_ids.shape[-1]
|
|
if max_gen_len == 1:
|
|
# No spec decode tokens.
|
|
valid_sampled_token_ids = self._to_list(sampled_token_ids)
|
|
else:
|
|
# Includes spec decode tokens.
|
|
valid_sampled_token_ids = self.rejection_sampler.parse_output(
|
|
sampled_token_ids,
|
|
self.input_batch.vocab_size,
|
|
)
|
|
# Mask out the sampled tokens that should not be sampled.
|
|
for i in discard_sampled_tokens_req_indices:
|
|
valid_sampled_token_ids[int(i)].clear()
|
|
else:
|
|
valid_sampled_token_ids = []
|
|
invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
|
|
invalid_req_indices_set = set(invalid_req_indices)
|
|
assert sampled_token_ids.shape[-1] == 1
|
|
|
|
# Cache the sampled tokens on the GPU and avoid CPU sync.
|
|
# These will be copied into input_ids in the next step
|
|
# when preparing inputs.
|
|
self.input_batch.prev_sampled_token_ids = sampled_token_ids
|
|
self.input_batch.prev_sampled_token_ids_invalid_indices = (
|
|
invalid_req_indices_set
|
|
)
|
|
self.input_batch.prev_req_id_to_index = {
|
|
req_id: i
|
|
for i, req_id in enumerate(self.input_batch.req_ids)
|
|
if i not in invalid_req_indices_set
|
|
}
|
|
|
|
# Cache the sampled tokens in the model runner, so that the scheduler
|
|
# doesn't need to send them back.
|
|
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
|
|
# the sampled tokens back, because there's no direct communication
|
|
# between the first-stage worker and the last-stage worker.
|
|
req_ids = self.input_batch.req_ids
|
|
for req_idx in range(num_sampled_tokens):
|
|
if self.use_async_scheduling:
|
|
sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
|
|
else:
|
|
sampled_ids = valid_sampled_token_ids[req_idx]
|
|
if not sampled_ids:
|
|
continue
|
|
|
|
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
|
|
end_idx = start_idx + len(sampled_ids)
|
|
assert end_idx <= self.max_model_len + 1, (
|
|
"Sampled token IDs exceed the max model length + 1. "
|
|
f"Total number of tokens: {end_idx} > max_model_len + 1: "
|
|
f"{self.max_model_len + 1}"
|
|
)
|
|
|
|
n_tokens_cache = len(sampled_ids)
|
|
|
|
# Sampled token IDs may exceed the max model length by 1. This is
|
|
# legitimate as we can still sample 1 last token when the context
|
|
# length equals the max model length. Note that we do not need to
|
|
# cache this token ID as the sequence finishes after this step.
|
|
# Additionally, the buffers token_ids_cpu and is_token_ids are of
|
|
# size max model length only.
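# Illustrative sketch (assumed values): with max_model_len=8,
# start_idx=7 and two sampled ids, end_idx is 9 == max_model_len + 1,
# so only the first id is written into the CPU buffers and the final
# one is skipped, since the sequence finishes after this step.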
|
|
if end_idx == self.max_model_len + 1:
|
|
n_tokens_cache -= 1
|
|
|
|
self.input_batch.token_ids_cpu[
|
|
req_idx, start_idx : (start_idx + n_tokens_cache)
|
|
] = sampled_ids[:n_tokens_cache]
|
|
self.input_batch.is_token_ids[
|
|
req_idx, start_idx : (start_idx + n_tokens_cache)
|
|
] = True
|
|
|
|
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
|
|
self.input_batch.num_tokens[req_idx] = end_idx
|
|
|
|
req_id = req_ids[req_idx]
|
|
req_state = self.requests[req_id]
|
|
req_state.output_token_ids.extend(sampled_ids)
|
|
|
|
return (
|
|
num_nans_in_logits,
|
|
logprobs_lists,
|
|
valid_sampled_token_ids,
|
|
prompt_logprobs_dict,
|
|
req_ids_output_copy,
|
|
req_id_to_index_output_copy,
|
|
invalid_req_indices,
|
|
)
|
|
|
|
@contextmanager
|
|
def synchronize_input_prep(self):
|
|
if self.prepare_inputs_event is None:
|
|
yield
|
|
return
|
|
|
|
# Ensure prior step has finished with reused CPU tensors.
|
|
# This is required in the async scheduling case because
|
|
# the CPU->GPU transfer happens async.
|
|
self.prepare_inputs_event.synchronize()
|
|
try:
|
|
yield
|
|
finally:
|
|
self.prepare_inputs_event.record()
|
|
|
|
def _model_forward(
|
|
self,
|
|
input_ids: Optional[torch.Tensor] = None,
|
|
positions: Optional[torch.Tensor] = None,
|
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
inputs_embeds: Optional[torch.Tensor] = None,
|
|
**model_kwargs: dict[str, Any],
|
|
) -> Any:
|
|
"""Helper method to call the model forward pass.
|
|
|
|
This method can be overridden by subclasses for model execution.
|
|
Motivation: We can inspect only this method versus
|
|
the whole execute_model, which has additional logic.
|
|
|
|
Args:
|
|
input_ids: Input token IDs
|
|
positions: Token positions
|
|
intermediate_tensors: Tensors from previous pipeline stages
|
|
inputs_embeds: Input embeddings (alternative to input_ids)
|
|
**model_kwargs: Additional model arguments
|
|
|
|
Returns:
|
|
Model output tensor
|
|
"""
|
|
return self.model(
|
|
input_ids=input_ids,
|
|
positions=positions,
|
|
intermediate_tensors=intermediate_tensors,
|
|
inputs_embeds=inputs_embeds,
|
|
**model_kwargs,
|
|
)
|
|
|
|
@torch.inference_mode()
|
|
def execute_model(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
|
|
with record_function_or_nullcontext("Preprocess"):
|
|
with self.synchronize_input_prep():
|
|
# Update persistent batch states.
|
|
self._update_states(scheduler_output)
|
|
|
|
if not scheduler_output.total_num_scheduled_tokens:
|
|
if not has_kv_transfer_group():
|
|
# Return empty ModelRunnerOutput if no work to do.
|
|
return EMPTY_MODEL_RUNNER_OUTPUT
|
|
return self.kv_connector_no_forward(
|
|
scheduler_output, self.vllm_config
|
|
)
|
|
if self.cache_config.kv_sharing_fast_prefill:
|
|
assert not self.input_batch.num_prompt_logprobs, (
|
|
"--kv-sharing-fast-prefill produces incorrect "
|
|
"logprobs for prompt tokens, tokens, please disable "
|
|
"it when the requests need prompt logprobs"
|
|
)
|
|
|
|
# Prepare the decoder inputs.
|
|
(
|
|
attn_metadata,
|
|
logits_indices,
|
|
spec_decode_metadata,
|
|
num_scheduled_tokens_np,
|
|
spec_decode_common_attn_metadata,
|
|
max_query_len,
|
|
ubatch_slices,
|
|
num_tokens_after_padding,
|
|
use_cascade_attn,
|
|
) = self._prepare_inputs(scheduler_output)
|
|
|
|
(
|
|
num_scheduled_tokens,
|
|
num_input_tokens,
|
|
num_tokens_across_dp,
|
|
input_ids,
|
|
inputs_embeds,
|
|
positions,
|
|
intermediate_tensors,
|
|
model_kwargs,
|
|
) = self._preprocess(
|
|
scheduler_output,
|
|
intermediate_tensors,
|
|
ubatch_slices,
|
|
num_tokens_after_padding,
|
|
)
|
|
|
|
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
|
|
num_scheduled_tokens == self.input_batch.num_reqs * max_query_len
|
|
)
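# Illustrative sketch (assumed values): with uniform_decode_query_len=1
# and 8 decode-only requests each scheduled exactly one token,
# max_query_len == 1 and num_scheduled_tokens == 8 == num_reqs * 1, so
# the batch is treated as a uniform decode batch.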
|
|
batch_descriptor = BatchDescriptor(
|
|
num_tokens=num_input_tokens, uniform_decode=uniform_decode
|
|
)
|
|
cudagraph_runtime_mode, batch_descriptor = (
|
|
self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn)
|
|
)
|
|
|
|
# Set cudagraph mode to none if calc_kv_scales is true.
|
|
if attn_metadata is not None:
|
|
metadata_list = (
|
|
attn_metadata.values()
|
|
if isinstance(attn_metadata, dict)
|
|
else [attn_metadata]
|
|
)
|
|
if any(
|
|
getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list
|
|
):
|
|
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
|
|
|
# This is currently to get around the assert in the DPMetadata
|
|
# where it wants `num_tokens_across_dp` to align with `num_tokens`
|
|
if ubatch_slices is not None:
|
|
num_input_tokens = ubatch_slices[0].num_tokens
|
|
|
|
# Run the model.
|
|
# Use persistent buffers for CUDA graphs.
|
|
with (
|
|
set_forward_context(
|
|
attn_metadata,
|
|
self.vllm_config,
|
|
num_tokens=num_input_tokens,
|
|
num_tokens_across_dp=num_tokens_across_dp,
|
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
|
batch_descriptor=batch_descriptor,
|
|
ubatch_slices=ubatch_slices,
|
|
),
|
|
record_function_or_nullcontext("Forward"),
|
|
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
|
|
):
|
|
model_output = self._model_forward(
|
|
input_ids=input_ids,
|
|
positions=positions,
|
|
intermediate_tensors=intermediate_tensors,
|
|
inputs_embeds=inputs_embeds,
|
|
**model_kwargs,
|
|
)
|
|
|
|
with record_function_or_nullcontext("Postprocess"):
|
|
if self.use_aux_hidden_state_outputs:
|
|
# True when EAGLE 3 is used.
|
|
hidden_states, aux_hidden_states = model_output
|
|
else:
|
|
# Common case.
|
|
hidden_states = model_output
|
|
aux_hidden_states = None
|
|
|
|
if not self.broadcast_pp_output:
|
|
# Common case.
|
|
if not get_pp_group().is_last_rank:
|
|
# Return the intermediate tensors.
|
|
assert isinstance(hidden_states, IntermediateTensors)
|
|
hidden_states.kv_connector_output = kv_connector_output
|
|
return hidden_states
|
|
|
|
if self.is_pooling_model:
|
|
# Return the pooling output.
|
|
output = self._pool(
|
|
hidden_states, num_scheduled_tokens, num_scheduled_tokens_np
|
|
)
|
|
output.kv_connector_output = kv_connector_output
|
|
return output
|
|
|
|
sample_hidden_states = hidden_states[logits_indices]
|
|
logits = self.model.compute_logits(sample_hidden_states)
|
|
else:
|
|
# Rare case.
|
|
assert not self.is_pooling_model
|
|
|
|
if not get_pp_group().is_last_rank:
|
|
all_gather_tensors = {
|
|
"residual": not is_residual_scattered_for_sp(
|
|
self.vllm_config, num_input_tokens
|
|
)
|
|
}
|
|
get_pp_group().send_tensor_dict(
|
|
hidden_states.tensors,
|
|
all_gather_group=get_tp_group(),
|
|
all_gather_tensors=all_gather_tensors,
|
|
)
|
|
logits = None
|
|
else:
|
|
sample_hidden_states = hidden_states[logits_indices]
|
|
logits = self.model.compute_logits(sample_hidden_states)
|
|
|
|
model_output_broadcast_data = {}
|
|
if logits is not None:
|
|
model_output_broadcast_data["logits"] = logits.contiguous()
|
|
|
|
model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
|
|
model_output_broadcast_data, src=len(get_pp_group().ranks) - 1
|
|
)
|
|
assert model_output_broadcast_data is not None
|
|
logits = model_output_broadcast_data["logits"]
|
|
|
|
# Apply structured output bitmasks if present
|
|
if scheduler_output.grammar_bitmask is not None:
|
|
apply_grammar_bitmask(
|
|
scheduler_output, self.input_batch, logits, self.device
|
|
)
|
|
|
|
with record_function_or_nullcontext("Sample"):
|
|
sampler_output = self._sample(logits, spec_decode_metadata)
|
|
|
|
def propose_draft_token_ids(sampled_token_ids):
|
|
assert spec_decode_common_attn_metadata is not None
|
|
with record_function_or_nullcontext("Draft"):
|
|
self._draft_token_ids = self.propose_draft_token_ids(
|
|
scheduler_output,
|
|
sampled_token_ids,
|
|
self.input_batch.sampling_metadata,
|
|
hidden_states,
|
|
sample_hidden_states,
|
|
aux_hidden_states,
|
|
spec_decode_metadata,
|
|
spec_decode_common_attn_metadata,
|
|
)
|
|
|
|
use_padded_batch_for_eagle = (
|
|
self.speculative_config
|
|
and self.speculative_config.use_eagle()
|
|
and not self.speculative_config.disable_padded_drafter_batch
|
|
)
|
|
effective_drafter_max_model_len = self.max_model_len
|
|
if effective_drafter_max_model_len is None:
|
|
effective_drafter_max_model_len = self.model_config.max_model_len
|
|
if (
|
|
self.speculative_config
|
|
and self.speculative_config.draft_model_config is not None
|
|
and self.speculative_config.draft_model_config.max_model_len is not None
|
|
):
|
|
effective_drafter_max_model_len = (
|
|
self.speculative_config.draft_model_config.max_model_len
|
|
)
|
|
input_fits_in_drafter = spec_decode_common_attn_metadata and (
|
|
spec_decode_common_attn_metadata.max_seq_len
|
|
+ self.speculative_config.num_speculative_tokens
|
|
<= effective_drafter_max_model_len
|
|
)
|
|
if use_padded_batch_for_eagle and input_fits_in_drafter:
|
|
# EAGLE speculative decoding can use the GPU sampled tokens
|
|
# as inputs, and does not need to wait for bookkeeping to finish.
|
|
propose_draft_token_ids(sampler_output.sampled_token_ids)
|
|
|
|
with record_function_or_nullcontext("Bookkeep"):
|
|
(
|
|
num_nans_in_logits,
|
|
logprobs_lists,
|
|
valid_sampled_token_ids,
|
|
prompt_logprobs_dict,
|
|
req_ids_output_copy,
|
|
req_id_to_index_output_copy,
|
|
invalid_req_indices,
|
|
) = self._bookkeeping_sync(
|
|
scheduler_output,
|
|
sampler_output,
|
|
logits,
|
|
hidden_states,
|
|
num_scheduled_tokens,
|
|
)
|
|
|
|
if (
|
|
self.speculative_config
|
|
and not use_padded_batch_for_eagle
|
|
and input_fits_in_drafter
|
|
):
|
|
# ngram and other speculative decoding methods use the sampled
|
|
# tokens on the CPU, so they are run after bookkeeping.
|
|
propose_draft_token_ids(valid_sampled_token_ids)
|
|
|
|
with record_function_or_nullcontext("EPLB"):
|
|
self.eplb_step()
|
|
|
|
output = ModelRunnerOutput(
|
|
req_ids=req_ids_output_copy,
|
|
req_id_to_index=req_id_to_index_output_copy,
|
|
sampled_token_ids=valid_sampled_token_ids,
|
|
logprobs=logprobs_lists,
|
|
prompt_logprobs_dict=prompt_logprobs_dict,
|
|
pooler_output=[],
|
|
kv_connector_output=kv_connector_output,
|
|
num_nans_in_logits=num_nans_in_logits,
|
|
)
|
|
|
|
if not self.use_async_scheduling:
|
|
return output
|
|
|
|
return AsyncGPUModelRunnerOutput(
|
|
model_runner_output=output,
|
|
sampled_token_ids=sampler_output.sampled_token_ids,
|
|
invalid_req_indices=invalid_req_indices,
|
|
async_output_copy_stream=self.async_output_copy_stream,
|
|
)
|
|
|
|
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
|
|
if self._draft_token_ids is None:
|
|
return None
|
|
req_ids = self.input_batch.req_ids
|
|
if isinstance(self._draft_token_ids, torch.Tensor):
|
|
draft_token_ids = self._draft_token_ids.tolist()
|
|
else:
|
|
draft_token_ids = self._draft_token_ids
|
|
self._draft_token_ids = None
|
|
return DraftTokenIds(req_ids, draft_token_ids)
|
|
|
|
def propose_draft_token_ids(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
sampled_token_ids: Union[torch.Tensor, list[list[int]]],
|
|
sampling_metadata: SamplingMetadata,
|
|
hidden_states: torch.Tensor,
|
|
sample_hidden_states: torch.Tensor,
|
|
aux_hidden_states: Optional[list[torch.Tensor]],
|
|
spec_decode_metadata: Optional[SpecDecodeMetadata],
|
|
common_attn_metadata: CommonAttentionMetadata,
|
|
) -> Union[list[list[int]], torch.Tensor]:
|
|
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
|
if self.speculative_config.method == "ngram":
|
|
assert isinstance(sampled_token_ids, list)
|
|
assert isinstance(self.drafter, NgramProposer)
|
|
draft_token_ids = self.drafter.propose(
|
|
sampled_token_ids,
|
|
self.input_batch.req_ids,
|
|
self.input_batch.num_tokens_no_spec,
|
|
self.input_batch.token_ids_cpu,
|
|
self.input_batch.spec_decode_unsupported_reqs,
|
|
)
|
|
elif self.speculative_config.method == "medusa":
|
|
assert isinstance(sampled_token_ids, list)
|
|
assert isinstance(self.drafter, MedusaProposer)
|
|
|
|
if sample_hidden_states.shape[0] == len(sampled_token_ids):
|
|
# The input to the target model does not include draft tokens.
|
|
hidden_states = sample_hidden_states
|
|
else:
|
|
indices = []
|
|
offset = 0
|
|
assert spec_decode_metadata is not None
|
|
for num_draft, tokens in zip(
|
|
spec_decode_metadata.num_draft_tokens, sampled_token_ids
|
|
):
|
|
indices.append(offset + len(tokens) - 1)
|
|
offset += num_draft + 1
|
|
indices = torch.tensor(indices, device=self.device)
|
|
hidden_states = sample_hidden_states[indices]
|
|
|
|
draft_token_ids = self.drafter.propose(
|
|
target_hidden_states=hidden_states,
|
|
sampling_metadata=sampling_metadata,
|
|
)
|
|
elif self.speculative_config.use_eagle():
|
|
assert isinstance(self.drafter, EagleProposer)
|
|
|
|
if self.speculative_config.disable_padded_drafter_batch:
|
|
# When padded-batch is disabled, the sampled_token_ids should be
|
|
# the cpu-side list[list[int]] of valid sampled tokens for each
|
|
# request, with invalid requests having empty lists.
|
|
assert isinstance(sampled_token_ids, list), (
|
|
"sampled_token_ids should be a python list when"
|
|
"padded-batch is disabled."
|
|
)
|
|
next_token_ids = self.drafter.prepare_next_token_ids_cpu(
|
|
sampled_token_ids,
|
|
self.requests,
|
|
self.input_batch,
|
|
scheduler_output.num_scheduled_tokens,
|
|
)
|
|
else:
|
|
# When using padded-batch, the sampled_token_ids should be
|
|
# the gpu tensor of sampled tokens for each request, of shape
|
|
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
|
|
# value -1.
|
|
assert isinstance(sampled_token_ids, torch.Tensor), (
|
|
"sampled_token_ids should be a torch.Tensor when"
|
|
"padded-batch is enabled."
|
|
)
|
|
next_token_ids, valid_sampled_tokens_count = (
|
|
self.drafter.prepare_next_token_ids_padded(
|
|
common_attn_metadata,
|
|
sampled_token_ids,
|
|
self.requests,
|
|
self.input_batch,
|
|
self.discard_request_indices.gpu,
|
|
self.num_discarded_requests,
|
|
)
|
|
)
|
|
|
|
if spec_decode_metadata is None:
|
|
token_indices_to_sample = None
|
|
# input_ids can be None for multimodal models.
|
|
target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
|
|
target_positions = self._get_positions(num_scheduled_tokens)
|
|
if self.use_aux_hidden_state_outputs:
|
|
assert aux_hidden_states is not None
|
|
target_hidden_states = torch.cat(
|
|
[h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1
|
|
)
|
|
else:
|
|
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
|
else:
|
|
if self.speculative_config.disable_padded_drafter_batch:
|
|
token_indices_to_sample = None
|
|
common_attn_metadata, token_indices = self.drafter.prepare_inputs(
|
|
common_attn_metadata,
|
|
sampled_token_ids,
|
|
spec_decode_metadata.num_draft_tokens,
|
|
)
|
|
else:
|
|
common_attn_metadata, token_indices, token_indices_to_sample = (
|
|
self.drafter.prepare_inputs_padded(
|
|
common_attn_metadata,
|
|
spec_decode_metadata,
|
|
valid_sampled_tokens_count,
|
|
)
|
|
)
|
|
|
|
target_token_ids = self.input_ids.gpu[token_indices]
|
|
target_positions = self._get_positions(token_indices)
|
|
if self.use_aux_hidden_state_outputs:
|
|
assert aux_hidden_states is not None
|
|
target_hidden_states = torch.cat(
|
|
[h[token_indices] for h in aux_hidden_states], dim=-1
|
|
)
|
|
else:
|
|
target_hidden_states = hidden_states[token_indices]
|
|
|
|
if self.supports_mm_inputs:
|
|
mm_embed_inputs = self._gather_mm_embeddings(
|
|
scheduler_output,
|
|
shift_computed_tokens=1,
|
|
)
|
|
else:
|
|
mm_embed_inputs = None
|
|
|
|
draft_token_ids = self.drafter.propose(
|
|
target_token_ids=target_token_ids,
|
|
target_positions=target_positions,
|
|
target_hidden_states=target_hidden_states,
|
|
next_token_ids=next_token_ids,
|
|
last_token_indices=token_indices_to_sample,
|
|
sampling_metadata=sampling_metadata,
|
|
common_attn_metadata=common_attn_metadata,
|
|
mm_embed_inputs=mm_embed_inputs,
|
|
)
|
|
|
|
return draft_token_ids
|
|
|
|
def update_config(self, overrides: dict[str, Any]) -> None:
|
|
allowed_config_names = {"load_config", "model_config"}
|
|
for config_name, config_overrides in overrides.items():
|
|
assert config_name in allowed_config_names, (
|
|
f"Config `{config_name}` not supported. "
|
|
f"Allowed configs: {allowed_config_names}"
|
|
)
|
|
config = getattr(self, config_name)
|
|
new_config = update_config(config, config_overrides)
|
|
setattr(self, config_name, new_config)
|
|
|
|
def load_model(self, eep_scale_up: bool = False) -> None:
|
|
"""
|
|
Args:
|
|
eep_scale_up: the model loading is for elastic EP scale up.
|
|
"""
|
|
logger.info("Starting to load model %s...", self.model_config.model)
|
|
if eep_scale_up:
|
|
from vllm.distributed.parallel_state import get_ep_group
|
|
|
|
num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu")
|
|
torch.distributed.broadcast(
|
|
num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0
|
|
)
|
|
num_local_physical_experts = int(num_local_physical_experts.item())
|
|
new_ep_size = get_ep_group().world_size
|
|
global_expert_load, old_global_expert_indices = EplbState.recv_state()
|
|
num_logical_experts = global_expert_load.shape[1]
|
|
self.parallel_config.eplb_config.num_redundant_experts = (
|
|
num_local_physical_experts * new_ep_size - num_logical_experts
|
|
)
|
|
assert old_global_expert_indices.shape[1] % num_local_physical_experts == 0
|
|
old_ep_size = (
|
|
old_global_expert_indices.shape[1] // num_local_physical_experts
|
|
)
|
|
rank_mapping = {
|
|
old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)
|
|
}
|
|
else:
|
|
global_expert_load = None
|
|
old_global_expert_indices = None
|
|
rank_mapping = None
|
|
|
|
with DeviceMemoryProfiler() as m:
|
|
time_before_load = time.perf_counter()
|
|
model_loader = get_model_loader(self.load_config)
|
|
logger.info("Loading model from scratch...")
|
|
self.model = model_loader.load_model(
|
|
vllm_config=self.vllm_config, model_config=self.model_config
|
|
)
|
|
if self.lora_config:
|
|
self.model = self.load_lora_model(
|
|
self.model, self.vllm_config, self.device
|
|
)
|
|
if hasattr(self, "drafter"):
|
|
logger.info("Loading drafter model...")
|
|
self.drafter.load_model(self.model)
|
|
if self.use_aux_hidden_state_outputs:
|
|
if not supports_eagle3(self.model):
|
|
raise RuntimeError(
|
|
"Model does not support EAGLE3 interface but "
|
|
"aux_hidden_state_outputs was requested"
|
|
)
|
|
|
|
# Try to get auxiliary layers from speculative config,
|
|
# otherwise use model's default layers
|
|
aux_layers = self._get_eagle3_aux_layers_from_config()
|
|
if aux_layers:
|
|
logger.info(
|
|
"Using auxiliary layers from speculative config: %s",
|
|
aux_layers,
|
|
)
|
|
else:
|
|
aux_layers = self.model.get_eagle3_aux_hidden_state_layers()
|
|
|
|
self.model.set_aux_hidden_state_layers(aux_layers)
|
|
time_after_load = time.perf_counter()
|
|
self.model_memory_usage = m.consumed_memory
|
|
logger.info(
|
|
"Model loading took %.4f GiB and %.6f seconds",
|
|
self.model_memory_usage / GiB_bytes,
|
|
time_after_load - time_before_load,
|
|
)
|
|
prepare_communication_buffer_for_model(self.model)
|
|
|
|
self.is_multimodal_pruning_enabled = (
|
|
supports_multimodal_pruning(self.model)
|
|
and self.model_config.multimodal_config.is_multimodal_pruning_enabled()
|
|
)
|
|
|
|
if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb:
|
|
logger.info("EPLB is enabled for model %s.", self.model_config.model)
|
|
self.eplb_state = EplbState.build(
|
|
self.model,
|
|
self.device,
|
|
self.parallel_config,
|
|
global_expert_load,
|
|
old_global_expert_indices,
|
|
rank_mapping,
|
|
)
|
|
|
|
if (
|
|
self.vllm_config.compilation_config.level == CompilationLevel.DYNAMO_AS_IS
|
|
and supports_dynamo()
|
|
):
|
|
backend = self.vllm_config.compilation_config.init_backend(self.vllm_config)
|
|
compilation_counter.dynamo_as_is_count += 1
|
|
self.model.compile(fullgraph=True, backend=backend)
|
|
return
|
|
# for other compilation levels, cudagraph behavior is controlled by
|
|
# CUDAGraphWrapper and CudagraphDispatcher of vLLM.
|
|
|
|
# wrap the model with full cudagraph wrapper if needed.
|
|
if (
|
|
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
|
and not self.parallel_config.enable_dbo
|
|
):
|
|
self.model = CUDAGraphWrapper(
|
|
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
|
|
)
|
|
elif self.parallel_config.enable_dbo:
|
|
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
|
self.model = UBatchWrapper(
|
|
self.model, self.vllm_config, CUDAGraphMode.FULL, self.device
|
|
)
|
|
else:
|
|
self.model = UBatchWrapper(
|
|
self.model, self.vllm_config, CUDAGraphMode.NONE, self.device
|
|
)
|
|
|
|
def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
|
|
"""Extract Eagle3 auxiliary layer indices from speculative config.
|
|
|
|
These indices specify which hidden states from the base model should
|
|
be used as auxiliary inputs for the Eagle3 drafter model during
|
|
speculative decoding.
|
|
|
|
Returns:
|
|
Tuple of layer indices if found in draft model config,
|
|
None otherwise.
|
|
"""
|
|
if not (self.speculative_config and self.speculative_config.draft_model_config):
|
|
return None
|
|
|
|
hf_config = self.speculative_config.draft_model_config.hf_config
|
|
if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
|
|
return None
|
|
|
|
layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
|
|
if layer_ids and isinstance(layer_ids, (list, tuple)):
|
|
return tuple(layer_ids)
|
|
|
|
return None
|
|
|
|
def reload_weights(self) -> None:
|
|
assert getattr(self, "model", None) is not None, (
|
|
"Cannot reload weights before model is loaded."
|
|
)
|
|
model_loader = get_model_loader(self.load_config)
|
|
logger.info("Reloading weights inplace...")
|
|
model_loader.load_weights(self.get_model(), model_config=self.model_config)
|
|
|
|
def save_tensorized_model(
|
|
self,
|
|
tensorizer_config: "TensorizerConfig",
|
|
) -> None:
|
|
TensorizerLoader.save_model(
|
|
self.get_model(),
|
|
tensorizer_config=tensorizer_config,
|
|
model_config=self.model_config,
|
|
)
|
|
|
|
def _get_prompt_logprobs_dict(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
num_scheduled_tokens: dict[str, int],
|
|
) -> dict[str, Optional[LogprobsTensors]]:
|
|
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
|
|
if not num_prompt_logprobs_dict:
|
|
return {}
|
|
|
|
in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
|
|
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
|
|
|
|
# Since prompt logprobs are a rare feature, prioritize a simple,
|
|
# maintainable loop over optimal performance.
|
|
completed_prefill_reqs = []
|
|
for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
|
|
num_tokens = num_scheduled_tokens[req_id]
|
|
|
|
# Get metadata for this request.
|
|
request = self.requests[req_id]
|
|
if request.prompt_token_ids is None:
|
|
# Prompt logprobs is incompatible with prompt embeddings
|
|
continue
|
|
|
|
num_prompt_tokens = len(request.prompt_token_ids)
|
|
prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
|
|
self.device, non_blocking=True
|
|
)
|
|
|
|
# Set up target LogprobsTensors object.
|
|
logprobs_tensors = in_progress_dict.get(req_id)
|
|
if not logprobs_tensors:
|
|
# Create empty logprobs CPU tensors for the entire prompt.
|
|
# If chunked, we'll copy in slice by slice.
|
|
logprobs_tensors = LogprobsTensors.empty_cpu(
|
|
num_prompt_tokens - 1, num_prompt_logprobs + 1
|
|
)
|
|
in_progress_dict[req_id] = logprobs_tensors
|
|
|
|
# Determine number of logits to retrieve.
|
|
start_idx = request.num_computed_tokens
|
|
start_tok = start_idx + 1
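            # Prompt logprobs exist for positions 1..num_prompt_tokens-1: the
            # logprob of prompt token i is computed from the logits at
            # position i-1, hence the +1 offset here.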
|
|
num_remaining_tokens = num_prompt_tokens - start_tok
|
|
if num_tokens <= num_remaining_tokens:
|
|
# This is a chunk, more tokens remain.
|
|
# In the == case, there are no more prompt logprobs to produce
|
|
# but we want to defer returning them to the next step where we
|
|
# have new generated tokens to return.
|
|
num_logits = num_tokens
|
|
else:
|
|
# This is the last chunk of prompt tokens to return.
|
|
num_logits = num_remaining_tokens
|
|
completed_prefill_reqs.append(req_id)
|
|
prompt_logprobs_dict[req_id] = logprobs_tensors
|
|
|
|
if num_logits <= 0:
|
|
# This can happen for the final chunk if we prefilled exactly
|
|
# (num_prompt_tokens - 1) tokens for this request in the prior
|
|
# step. There are no more prompt logprobs to produce.
|
|
continue
|
|
|
|
# Get the logits corresponding to this req's prompt tokens.
|
|
# If this is a partial request (i.e. chunked prefill),
|
|
# then there is prompt logprob generated for each index.
|
|
req_idx = self.input_batch.req_id_to_index[req_id]
|
|
offset = self.query_start_loc.np[req_idx].item()
|
|
prompt_hidden_states = hidden_states[offset : offset + num_logits]
|
|
logits = self.model.compute_logits(prompt_hidden_states)
|
|
|
|
# Get the "target" tokens for each index. For prompt at index i,
|
|
# the token at prompt index i+1 is the "sampled" token we want
|
|
# to gather the logprob for.
|
|
tgt_token_ids = prompt_token_ids[start_tok : start_tok + num_logits]
|
|
|
|
# Compute prompt logprobs.
|
|
logprobs = self.sampler.compute_logprobs(logits)
|
|
token_ids, logprobs, ranks = self.sampler.gather_logprobs(
|
|
logprobs, num_prompt_logprobs, tgt_token_ids
|
|
)
|
|
|
|
# Transfer GPU->CPU async.
|
|
chunk_slice = slice(start_idx, start_idx + num_logits)
|
|
logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
|
|
token_ids, non_blocking=True
|
|
)
|
|
logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True)
|
|
logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
|
|
ranks, non_blocking=True
|
|
)
|
|
|
|
# Remove requests that have completed prefill from the batch
|
|
# num_prompt_logprobs_dict.
|
|
for req_id in completed_prefill_reqs:
|
|
del num_prompt_logprobs_dict[req_id]
|
|
del in_progress_dict[req_id]
|
|
|
|
# Must synchronize the non-blocking GPU->CPU transfers.
|
|
if prompt_logprobs_dict:
|
|
self._sync_device()
|
|
|
|
return prompt_logprobs_dict
|
|
|
|
def _get_nans_in_logits(
|
|
self,
|
|
logits: Optional[torch.Tensor],
|
|
) -> dict[str, int]:
|
|
try:
|
|
if logits is None:
|
|
return {req_id: 0 for req_id in self.input_batch.req_ids}
|
|
|
|
num_nans_in_logits = {}
|
|
num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy()
|
|
for req_id in self.input_batch.req_ids:
|
|
req_index = self.input_batch.req_id_to_index[req_id]
|
|
num_nans_in_logits[req_id] = (
|
|
int(num_nans_for_index[req_index])
|
|
if num_nans_for_index is not None and req_index < logits.shape[0]
|
|
else 0
|
|
)
|
|
return num_nans_in_logits
|
|
except IndexError:
|
|
return {}
|
|
|
|
@contextmanager
|
|
def maybe_randomize_inputs(self, input_ids: torch.Tensor):
|
|
"""
|
|
Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
|
|
This is to help balance expert-selection
|
|
- during profile_run
|
|
- during DP rank dummy run
|
|
"""
|
|
dp_size = self.vllm_config.parallel_config.data_parallel_size
|
|
randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
|
|
if not randomize_inputs:
|
|
yield
|
|
else:
|
|
import functools
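            # Sample token ids uniformly over the vocabulary so different DP
            # ranks exercise different expert routing during dummy runs.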
|
|
|
|
@functools.cache
|
|
def rand_input_ids() -> torch.Tensor:
|
|
return torch.randint_like(
|
|
self.input_ids.gpu,
|
|
low=0,
|
|
high=self.model_config.get_vocab_size(),
|
|
dtype=input_ids.dtype,
|
|
)
|
|
|
|
logger.debug_once("Randomizing dummy data for DP Rank")
|
|
input_ids.copy_(rand_input_ids()[: input_ids.size(0)], non_blocking=True)
|
|
yield
|
|
input_ids.fill_(0)
|
|
|
|
def _get_mm_dummy_batch(
|
|
self,
|
|
modality: str,
|
|
max_items_per_batch: int,
|
|
) -> BatchedTensorInputs:
|
|
"""Dummy data for profiling and precompiling multimodal models."""
|
|
assert self.mm_budget is not None
|
|
|
|
dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
|
|
model_config=self.model_config,
|
|
seq_len=self.max_model_len,
|
|
mm_counts={modality: 1},
|
|
cache=self.mm_budget.cache,
|
|
)
|
|
dummy_mm_data = dummy_decoder_data.multi_modal_data
|
|
|
|
        # Results in the maximum GPU memory consumption of the model.
|
|
dummy_mm_item = dummy_mm_data[modality][0]
|
|
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
|
|
|
|
model = cast(SupportsMultiModal, self.model)
|
|
return next(
|
|
mm_kwargs_group
|
|
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
|
|
dummy_mm_items,
|
|
device=self.device,
|
|
pin_memory=self.pin_memory,
|
|
merge_by_field_config=model.merge_by_field_config,
|
|
)
|
|
)
|
|
|
|
@torch.inference_mode()
|
|
def _dummy_run(
|
|
self,
|
|
num_tokens: int,
|
|
cudagraph_runtime_mode: Optional[CUDAGraphMode] = None,
|
|
force_attention: bool = False,
|
|
uniform_decode: bool = False,
|
|
allow_microbatching: bool = True,
|
|
skip_eplb: bool = False,
|
|
is_profile: bool = False,
|
|
create_mixed_batch: bool = False,
|
|
remove_lora: bool = True,
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Run a dummy forward pass to warm up/profile run or capture the
|
|
CUDA graph for the model.
|
|
|
|
Args:
|
|
num_tokens: Number of tokens to run the dummy forward pass.
|
|
cudagraph_runtime_mode: used to control the behavior.
|
|
                - if not set, the cudagraph mode is determined via
                    self.cudagraph_dispatcher.
|
|
- CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
|
|
- CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
|
|
- CUDAGraphMode.FULL: Full cudagraph, attention metadata is
|
|
needed.
|
|
force_attention: If True, always create attention metadata. Used to
|
|
warm up attention backend when mode is NONE.
|
|
uniform_decode: If True, the batch is a uniform decode batch.
|
|
skip_eplb: If True, skip EPLB state update.
|
|
is_profile: If True, this is a profile run.
|
|
create_mixed_batch: If True, create a mixed batch with both decode
|
|
(1 token) and prefill (multiple tokens) requests.
|
|
remove_lora: If False, dummy LoRAs are not destroyed after the run
|
|
"""
|
|
assert (
|
|
cudagraph_runtime_mode is None
|
|
or cudagraph_runtime_mode.valid_runtime_modes()
|
|
)
|
|
|
|
        # If cudagraph_mode.decode_mode() == FULL and
        # cudagraph_mode.separate_routine(), we are using different graphs
        # and/or modes for mixed prefill-decode batches vs. uniform decode
        # batches. A uniform decode batch means that all requests have
        # identical query length, except for a potential (shorter) virtual
        # request in the batch that accounts for padding.
|
|
# Uniform decode batch could either be common pure decode, where
|
|
# max_query_len == 1, or speculative decode, where
|
|
# max_query_len == 1 + num_spec_decode_tokens.
|
|
|
|
# When setting max_query_len = 1, we switch to and capture the optimized
|
|
# routine of FA2 for pure decode, i.e., Flashdecode + an optimization
|
|
# for GQA/MQA.
|
|
max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens
|
|
|
|
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
|
# for dummy run with LoRA so that the num_reqs collectively
|
|
# has num_tokens in total.
|
|
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
|
|
max_num_reqs = self.scheduler_config.max_num_seqs
|
|
if create_mixed_batch:
|
|
assert not uniform_decode
|
|
# Create mixed batch:
|
|
# first half decode tokens, second half one prefill
|
|
num_decode_tokens = min(max_num_reqs - 1, num_tokens // 2)
|
|
num_prefill_tokens = num_tokens - num_decode_tokens
|
|
num_reqs = num_decode_tokens + 1
|
|
|
|
# Create decode requests (1 token each) followed by prefill request
|
|
num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens]
|
|
# Note: Overriding max_query_len to be the prefill tokens
|
|
max_query_len = num_prefill_tokens
|
|
elif uniform_decode:
|
|
assert not create_mixed_batch
|
|
num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len))
|
|
num_scheduled_tokens_list = [max_query_len] * num_reqs
|
|
if num_tokens % max_query_len != 0:
|
|
num_scheduled_tokens_list[-1] = num_tokens % max_query_len
|
|
else:
|
|
num_reqs = min(num_tokens, max_num_reqs)
|
|
min_tokens_per_req = num_tokens // num_reqs
|
|
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
|
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
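            # e.g. num_tokens=10 with num_reqs=3 gives [3, 3, 4].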
|
|
|
|
assert sum(num_scheduled_tokens_list) == num_tokens
|
|
assert len(num_scheduled_tokens_list) == num_reqs
|
|
num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
|
|
total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
|
|
|
|
ubatch_slices = None
|
|
num_tokens_after_padding = None
|
|
|
|
# We currently only microbatch if the number of tokens is
|
|
# over a certain threshold.
|
|
if self.parallel_config.enable_dbo and allow_microbatching:
|
|
ubatch_slices, ubatch_num_tokens_after_padding = ubatch_split(
|
|
num_scheduled_tokens,
|
|
total_num_scheduled_tokens,
|
|
total_num_scheduled_tokens,
|
|
uniform_decode=uniform_decode,
|
|
vllm_config=self.vllm_config,
|
|
)
|
|
            # Currently when DBO is enabled `ubatch_split` returns the
            # num_tokens_after_padding for a single ubatch, but we have two,
            # hence the doubling below.
|
|
# TODO(sage,lucas): this is cruft that should be addressed in the
|
|
# padding refactor.
|
|
if ubatch_num_tokens_after_padding is not None:
|
|
num_tokens_after_padding = ubatch_num_tokens_after_padding * 2
|
|
|
|
# If we failed to microbatch, currently need to resynchronize
|
|
# TODO(lucas,sage): we should be able to avoid this second sync by
|
|
# refactoring `get_dp_padding_ubatch` and `get_dp_padding` into
|
|
# a single `coordinate_batch_across_dp` function.
|
|
if num_tokens_after_padding is None:
|
|
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
|
|
num_tokens_after_padding = num_tokens + num_pad
|
|
else:
|
|
num_tokens_across_dp = num_tokens_after_padding
|
|
num_tokens_after_padding = int(num_tokens_after_padding[0].item())
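        # At this point num_tokens_after_padding is this rank's (DP-padded)
        # token count, and num_tokens_across_dp, if set, carries the padded
        # counts across DP ranks for set_forward_context below.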
|
|
|
|
attn_metadata: Optional[PerLayerAttnMetadata] = None
|
|
|
|
# If force_attention is True, we always capture attention. Otherwise,
|
|
# it only happens for cudagraph_runtime_mode=FULL.
|
|
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
|
attn_metadata = {}
|
|
if ubatch_slices is not None:
|
|
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
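            # Full-cudagraph capture (and attention warm-up via
            # force_attention) needs real attention metadata, so it is built
            # below with build_for_cudagraph_capture for each KV cache group
            # (and per ubatch when micro-batching).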
|
|
|
|
if create_mixed_batch:
|
|
# In the mixed batch mode (used for FI warmup), we use
|
|
# shorter sequence lengths to run faster.
|
|
# TODO(luka) better system for describing dummy batches
|
|
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
|
|
else:
|
|
seq_lens = max_query_len
|
|
self.seq_lens.np[:num_reqs] = seq_lens
|
|
self.seq_lens.np[num_reqs:] = 0
|
|
self.seq_lens.copy_to_gpu()
|
|
|
|
cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
|
|
self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
|
|
self.query_start_loc.copy_to_gpu()
|
|
|
|
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
|
self.kv_cache_config.kv_cache_groups
|
|
):
|
|
common_attn_metadata = CommonAttentionMetadata(
|
|
query_start_loc=self.query_start_loc.gpu[: num_reqs + 1],
|
|
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1],
|
|
seq_lens=self.seq_lens.gpu[:num_reqs],
|
|
seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
|
|
num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[
|
|
:num_reqs
|
|
],
|
|
num_reqs=num_reqs,
|
|
num_actual_tokens=num_tokens,
|
|
max_query_len=max_query_len,
|
|
max_seq_len=self.max_model_len,
|
|
block_table_tensor=self.input_batch.block_table[
|
|
kv_cache_group_id
|
|
].get_device_tensor(num_reqs),
|
|
slot_mapping=self.input_batch.block_table[
|
|
kv_cache_group_id
|
|
].slot_mapping.gpu[:num_tokens],
|
|
causal=True,
|
|
)
|
|
for attn_group in self.attn_groups[kv_cache_group_id]:
|
|
if ubatch_slices is not None:
|
|
common_attn_metadata_list = split_attn_metadata(
|
|
ubatch_slices, common_attn_metadata
|
|
)
|
|
for ubid, common_attn_metadata in enumerate(
|
|
common_attn_metadata_list
|
|
):
|
|
assert common_attn_metadata.max_query_len == 1
|
|
attn_metadata_i = attn_group.get_metadata_builder(
|
|
ubatch_id=ubid
|
|
).build_for_cudagraph_capture(common_attn_metadata)
|
|
for layer_name in attn_group.layer_names:
|
|
assert type(attn_metadata) is list
|
|
attn_metadata[ubid][layer_name] = attn_metadata_i
|
|
else:
|
|
assert type(attn_metadata) is dict
|
|
metadata_builder = attn_group.get_metadata_builder()
|
|
attn_metadata_i = metadata_builder.build_for_cudagraph_capture(
|
|
common_attn_metadata
|
|
)
|
|
for layer_name in attn_group.layer_names:
|
|
attn_metadata[layer_name] = attn_metadata_i
|
|
|
|
with self.maybe_dummy_run_with_lora(
|
|
self.lora_config, num_scheduled_tokens, remove_lora
|
|
):
|
|
# Make sure padding doesn't exceed max_num_tokens
|
|
assert num_tokens_after_padding <= self.max_num_tokens
|
|
model_kwargs = self._init_model_kwargs(num_tokens_after_padding)
|
|
if self.supports_mm_inputs and not self.model_config.is_encoder_decoder:
|
|
input_ids = None
|
|
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding]
|
|
model_kwargs = {
|
|
**model_kwargs,
|
|
**self._dummy_mm_kwargs(num_reqs),
|
|
}
|
|
elif self.enable_prompt_embeds:
|
|
input_ids = None
|
|
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding]
|
|
model_kwargs = self._init_model_kwargs(num_tokens_after_padding)
|
|
else:
|
|
input_ids = self.input_ids.gpu[:num_tokens_after_padding]
|
|
inputs_embeds = None
|
|
|
|
if self.uses_mrope:
|
|
positions = self.mrope_positions.gpu[:, :num_tokens_after_padding]
|
|
else:
|
|
positions = self.positions.gpu[:num_tokens_after_padding]
|
|
|
|
if get_pp_group().is_first_rank:
|
|
intermediate_tensors = None
|
|
else:
|
|
if self.intermediate_tensors is None:
|
|
self.intermediate_tensors = (
|
|
self.model.make_empty_intermediate_tensors(
|
|
batch_size=self.max_num_tokens,
|
|
dtype=self.model_config.dtype,
|
|
device=self.device,
|
|
)
|
|
)
|
|
|
|
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
|
num_tokens_after_padding, None, False
|
|
)
|
|
|
|
            # Dispatch to get the cudagraph runtime mode and the matching
            # batch descriptor for this (padded) batch.
|
|
_cg_mode, batch_descriptor = (
|
|
self.cudagraph_dispatcher.dispatch(
|
|
BatchDescriptor(
|
|
num_tokens=num_tokens_after_padding,
|
|
uniform_decode=uniform_decode,
|
|
)
|
|
)
|
|
if not is_profile
|
|
else (CUDAGraphMode.NONE, None)
|
|
)
|
|
if cudagraph_runtime_mode is not None:
|
|
                # We allow forcing NONE even when the dispatcher disagrees,
                # to support warm-ups for cudagraph capture.
|
|
assert (
|
|
cudagraph_runtime_mode == CUDAGraphMode.NONE
|
|
or cudagraph_runtime_mode == _cg_mode
|
|
), (
|
|
f"Cudagraph runtime mode mismatch at dummy_run. "
|
|
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
|
|
)
|
|
else:
|
|
cudagraph_runtime_mode = _cg_mode
|
|
|
|
if ubatch_slices is not None:
|
|
# Adjust values to reflect a single ubatch.
|
|
# TODO(sage,lucas): this is cruft that should be addressed in
|
|
# the padding refactor.
|
|
num_tokens_after_padding = ubatch_slices[0].num_tokens
|
|
if num_tokens_across_dp is not None:
|
|
num_tokens_across_dp[:] = num_tokens_after_padding
|
|
|
|
with (
|
|
self.maybe_randomize_inputs(input_ids),
|
|
set_forward_context(
|
|
attn_metadata,
|
|
self.vllm_config,
|
|
num_tokens=num_tokens_after_padding,
|
|
num_tokens_across_dp=num_tokens_across_dp,
|
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
|
batch_descriptor=batch_descriptor,
|
|
ubatch_slices=ubatch_slices,
|
|
),
|
|
):
|
|
outputs = self.model(
|
|
input_ids=input_ids,
|
|
positions=positions,
|
|
intermediate_tensors=intermediate_tensors,
|
|
inputs_embeds=inputs_embeds,
|
|
**model_kwargs,
|
|
)
|
|
|
|
if self.use_aux_hidden_state_outputs:
|
|
hidden_states, _ = outputs
|
|
else:
|
|
hidden_states = outputs
|
|
|
|
if self.speculative_config and self.speculative_config.use_eagle():
|
|
assert isinstance(self.drafter, EagleProposer)
|
|
self.drafter.dummy_run(num_tokens)
|
|
|
|
# This is necessary to avoid blocking DP.
|
|
# For dummy runs, we typically skip EPLB since we don't have any real
|
|
# requests to process.
|
|
# However, in DP settings, there may be cases when some DP ranks do
|
|
# not have any requests to process, so they're executing dummy batches.
|
|
# In such cases, we still have to trigger EPLB to make sure
|
|
# ranks execute the rearrangement in synchronization.
|
|
if not skip_eplb:
|
|
self.eplb_step(is_dummy=True, is_profile=is_profile)
|
|
|
|
logit_indices = np.cumsum(num_scheduled_tokens) - 1
|
|
return hidden_states, hidden_states[logit_indices]
|
|
|
|
@torch.inference_mode()
|
|
def _dummy_sampler_run(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
# The dummy hidden states may contain special values,
|
|
# like `inf` or `nan`.
|
|
# To avoid breaking the sampler, we use a random tensor here instead.
|
|
hidden_states = torch.rand_like(hidden_states)
|
|
|
|
logits = self.model.compute_logits(hidden_states)
|
|
num_reqs = logits.size(0)
|
|
|
|
dummy_tensors = lambda v: torch.full((num_reqs,), v, device=self.device)
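        # Helper that broadcasts a scalar sampling parameter to one value per
        # request in the dummy batch.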
|
|
|
|
dummy_metadata = SamplingMetadata(
|
|
temperature=dummy_tensors(0.5),
|
|
all_greedy=False,
|
|
all_random=False,
|
|
top_p=dummy_tensors(0.9),
|
|
top_k=dummy_tensors(logits.size(1) - 1),
|
|
generators={},
|
|
max_num_logprobs=None,
|
|
no_penalties=True,
|
|
prompt_token_ids=None,
|
|
frequency_penalties=dummy_tensors(0.1),
|
|
presence_penalties=dummy_tensors(0.1),
|
|
repetition_penalties=dummy_tensors(0.1),
|
|
output_token_ids=[[] for _ in range(num_reqs)],
|
|
allowed_token_ids_mask=None,
|
|
bad_words_token_ids={},
|
|
logitsprocs=LogitsProcessors(),
|
|
)
|
|
try:
|
|
sampler_output = self.sampler(
|
|
logits=logits, sampling_metadata=dummy_metadata
|
|
)
|
|
except RuntimeError as e:
|
|
if "out of memory" in str(e):
|
|
raise RuntimeError(
|
|
"CUDA out of memory occurred when warming up sampler with "
|
|
f"{num_reqs} dummy requests. Please try lowering "
|
|
"`max_num_seqs` or `gpu_memory_utilization` when "
|
|
"initializing the engine."
|
|
) from e
|
|
else:
|
|
raise e
|
|
if self.speculative_config:
|
|
draft_token_ids = [[0] for _ in range(num_reqs)]
|
|
dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
|
draft_token_ids, self.device
|
|
)
|
|
|
|
num_tokens = sum(len(ids) for ids in draft_token_ids)
|
|
# draft_probs = torch.randn(
|
|
# num_tokens, logits.shape[-1], device=self.device,
|
|
# dtype=logits.dtype)
|
|
draft_probs = None
|
|
target_logits = torch.randn(
|
|
num_tokens, logits.shape[-1], device=self.device, dtype=logits.dtype
|
|
)
|
|
# NOTE(woosuk): Here, we should use int32 because the sampler uses
|
|
# int32 for bonus_token_ids. If the dtype mismatches, re-compilation
|
|
# will occur at runtime.
|
|
bonus_token_ids = torch.zeros(
|
|
num_reqs, device=self.device, dtype=torch.int32
|
|
)
|
|
self.rejection_sampler(
|
|
dummy_spec_decode_metadata,
|
|
draft_probs,
|
|
target_logits,
|
|
bonus_token_ids,
|
|
dummy_metadata,
|
|
)
|
|
return sampler_output
|
|
|
|
def _dummy_pooler_run_task(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
task: PoolingTask,
|
|
) -> PoolerOutput:
|
|
num_tokens = hidden_states.shape[0]
|
|
max_num_reqs = self.scheduler_config.max_num_seqs
|
|
num_reqs = min(num_tokens, max_num_reqs)
|
|
min_tokens_per_req = num_tokens // num_reqs
|
|
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
|
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
|
assert sum(num_scheduled_tokens_list) == num_tokens
|
|
assert len(num_scheduled_tokens_list) == num_reqs
|
|
|
|
req_num_tokens = num_tokens // num_reqs
|
|
|
|
dummy_prompt_lens = torch.tensor(
|
|
num_scheduled_tokens_list,
|
|
device="cpu",
|
|
)
|
|
dummy_token_ids = torch.zeros(
|
|
(num_reqs, req_num_tokens), dtype=torch.int32, device=self.device
|
|
)
|
|
|
|
model = cast(VllmModelForPooling, self.get_model())
|
|
dummy_pooling_params = PoolingParams(task=task)
|
|
dummy_pooling_params.verify(task=task, model_config=self.model_config)
|
|
to_update = model.pooler.get_pooling_updates(task)
|
|
to_update.apply(dummy_pooling_params)
|
|
|
|
dummy_metadata = PoolingMetadata(
|
|
prompt_lens=dummy_prompt_lens,
|
|
prompt_token_ids=dummy_token_ids,
|
|
pooling_params=[dummy_pooling_params] * num_reqs,
|
|
)
|
|
|
|
dummy_metadata.build_pooling_cursor(
|
|
num_scheduled_tokens_list, device=hidden_states.device
|
|
)
|
|
|
|
try:
|
|
return model.pooler(
|
|
hidden_states=hidden_states, pooling_metadata=dummy_metadata
|
|
)
|
|
except RuntimeError as e:
|
|
if "out of memory" in str(e):
|
|
raise RuntimeError(
|
|
"CUDA out of memory occurred when warming up pooler "
|
|
f"({task=}) with {num_reqs} dummy requests. Please try "
|
|
"lowering `max_num_seqs` or `gpu_memory_utilization` when "
|
|
"initializing the engine."
|
|
) from e
|
|
else:
|
|
raise e
|
|
|
|
@torch.inference_mode()
|
|
def _dummy_pooler_run(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
) -> PoolerOutput:
|
|
# Find the task that has the largest output for subsequent steps
|
|
output_size = dict[PoolingTask, float]()
|
|
for task in self.get_supported_pooling_tasks():
|
|
# Run a full batch with each task to ensure none of them OOMs
|
|
output = self._dummy_pooler_run_task(hidden_states, task)
|
|
output_size[task] = sum(o.nbytes for o in output)
|
|
del output # Allow GC
|
|
|
|
max_task = max(output_size.items(), key=lambda x: x[1])[0]
|
|
return self._dummy_pooler_run_task(hidden_states, max_task)
|
|
|
|
def profile_run(self) -> None:
|
|
# Profile with multimodal encoder & encoder cache.
|
|
if self.supports_mm_inputs:
|
|
if self.model_config.multimodal_config.skip_mm_profiling:
|
|
logger.info(
|
|
"Skipping memory profiling for multimodal encoder and "
|
|
"encoder cache."
|
|
)
|
|
else:
|
|
mm_budget = self.mm_budget
|
|
assert mm_budget is not None
|
|
|
|
if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
|
|
                    # NOTE: Currently the model is profiled with a single
                    # non-text
|
|
# modality with the max possible input tokens even when
|
|
# it supports multiple.
|
|
dummy_modality = mm_budget.get_modality_with_max_tokens()
|
|
max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[
|
|
dummy_modality
|
|
]
|
|
|
|
logger.info(
|
|
"Encoder cache will be initialized with a budget of "
|
|
"%s tokens, and profiled with %s %s items of the "
|
|
"maximum feature size.",
|
|
encoder_budget,
|
|
max_mm_items_per_batch,
|
|
dummy_modality,
|
|
)
|
|
|
|
# Create dummy batch of multimodal inputs.
|
|
batched_dummy_mm_inputs = self._get_mm_dummy_batch(
|
|
dummy_modality,
|
|
max_mm_items_per_batch,
|
|
)
|
|
|
|
# Run multimodal encoder.
|
|
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
|
|
**batched_dummy_mm_inputs
|
|
)
|
|
|
|
sanity_check_mm_encoder_outputs(
|
|
dummy_encoder_outputs,
|
|
expected_num_items=max_mm_items_per_batch,
|
|
)
|
|
|
|
# NOTE: This happens when encoder cache needs to store
|
|
# the embeddings that encoder outputs are scattered onto.
|
|
                    # In this case we create dummy embeddings of size
                    # (encoder_budget, hidden_size) and scatter the encoder
                    # outputs into them.
|
|
encoder_output_shape = dummy_encoder_outputs[0].shape
|
|
if encoder_output_shape[0] < encoder_budget:
|
|
expanded_outputs = []
|
|
for output in dummy_encoder_outputs:
|
|
expanded = output.new_zeros(
|
|
(encoder_budget, encoder_output_shape[-1])
|
|
)
|
|
num_tokens = output.shape[0]
|
|
expanded[:num_tokens].copy_(output)
|
|
expanded_outputs.append(expanded)
|
|
|
|
dummy_encoder_outputs = expanded_outputs
|
|
|
|
# Cache the dummy encoder outputs.
|
|
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
|
|
|
|
# Add `is_profile` here to pre-allocate communication buffers
|
|
hidden_states, last_hidden_states = self._dummy_run(
|
|
self.max_num_tokens, is_profile=True
|
|
)
|
|
if get_pp_group().is_last_rank:
|
|
if self.is_pooling_model:
|
|
output = self._dummy_pooler_run(hidden_states)
|
|
else:
|
|
output = self._dummy_sampler_run(last_hidden_states)
|
|
else:
|
|
output = None
|
|
self._sync_device()
|
|
del hidden_states, output
|
|
self.encoder_cache.clear()
|
|
gc.collect()
|
|
|
|
def capture_model(self) -> int:
|
|
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
|
|
logger.warning(
|
|
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
|
|
"ensure `cudagraph_mode` was not manually set to `NONE`"
|
|
)
|
|
return 0
|
|
else:
|
|
self.initialize_cudagraph_capture()
|
|
|
|
compilation_counter.num_gpu_runner_capture_triggers += 1
|
|
|
|
start_time = time.perf_counter()
|
|
|
|
@contextmanager
|
|
def freeze_gc():
|
|
# Optimize garbage collection during CUDA graph capture.
|
|
# Clean up, then freeze all remaining objects from being included
|
|
# in future collections.
|
|
gc.collect()
|
|
should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC
|
|
if should_freeze:
|
|
gc.freeze()
|
|
try:
|
|
yield
|
|
finally:
|
|
if should_freeze:
|
|
gc.unfreeze()
|
|
gc.collect()
|
|
|
|
# Trigger CUDA graph capture for specific shapes.
|
|
# Capture the large shapes first so that the smaller shapes
|
|
# can reuse the memory pool allocated for the large shapes.
|
|
set_cudagraph_capturing_enabled(True)
|
|
with freeze_gc(), graph_capture(device=self.device):
|
|
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
|
cudagraph_mode = self.compilation_config.cudagraph_mode
|
|
assert cudagraph_mode is not None
|
|
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
|
|
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
|
|
|
|
compilation_cases = list(reversed(self.cudagraph_batch_sizes))
|
|
self._capture_cudagraphs(
|
|
compilation_cases,
|
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
|
uniform_decode=False,
|
|
)
|
|
|
|
# Capture full cudagraph for uniform decode batches if we
|
|
# don't already have full mixed prefill-decode cudagraphs.
|
|
if (
|
|
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
|
|
and cudagraph_mode.separate_routine()
|
|
):
|
|
max_num_tokens = (
|
|
self.scheduler_config.max_num_seqs * self.uniform_decode_query_len
|
|
)
|
|
decode_cudagraph_batch_sizes = [
|
|
x
|
|
for x in self.cudagraph_batch_sizes
|
|
if x <= max_num_tokens and x >= self.uniform_decode_query_len
|
|
]
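                # Only capture sizes that correspond to whole uniform-decode
                # batches: at least one request's worth of tokens
                # (uniform_decode_query_len) and at most
                # max_num_seqs * uniform_decode_query_len tokens.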
|
|
compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes))
|
|
self._capture_cudagraphs(
|
|
compilation_cases=compilation_cases_decode,
|
|
cudagraph_runtime_mode=CUDAGraphMode.FULL,
|
|
uniform_decode=True,
|
|
)
|
|
|
|
torch.cuda.synchronize()
|
|
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
|
|
|
# Disable cudagraph capturing globally, so any unexpected cudagraph
|
|
# capturing will be detected and raise an error after here.
|
|
# Note: We don't put it into graph_capture context manager because
|
|
# we may do lazy capturing in future that still allows capturing
|
|
# after here.
|
|
set_cudagraph_capturing_enabled(False)
|
|
|
|
end_time = time.perf_counter()
|
|
elapsed_time = end_time - start_time
|
|
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
|
|
# This usually takes 5~20 seconds.
|
|
logger.info(
|
|
"Graph capturing finished in %.0f secs, took %.2f GiB",
|
|
elapsed_time,
|
|
cuda_graph_size / (1 << 30),
|
|
)
|
|
return cuda_graph_size
|
|
|
|
def _capture_cudagraphs(
|
|
self,
|
|
compilation_cases: list[int],
|
|
cudagraph_runtime_mode: CUDAGraphMode,
|
|
uniform_decode: bool,
|
|
):
|
|
assert (
|
|
cudagraph_runtime_mode != CUDAGraphMode.NONE
|
|
and cudagraph_runtime_mode.valid_runtime_modes()
|
|
), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}"
|
|
|
|
# Only rank 0 should print progress bar during capture
|
|
if is_global_first_rank():
|
|
compilation_cases = tqdm(
|
|
compilation_cases,
|
|
disable=not self.load_config.use_tqdm_on_load,
|
|
desc="Capturing CUDA graphs ({}, {})".format(
|
|
"decode" if uniform_decode else "mixed prefill-decode",
|
|
cudagraph_runtime_mode.name,
|
|
),
|
|
)
|
|
|
|
# We skip EPLB here since we don't want to record dummy metrics
|
|
for num_tokens in compilation_cases:
|
|
            # We currently only capture ubatched graphs when it's a FULL
            # cudagraph, a uniform decode batch, and the number of tokens
            # is above the threshold. Otherwise we just capture a
            # non-ubatched version of the graph.
|
|
allow_microbatching = (
|
|
self.parallel_config.enable_dbo
|
|
and cudagraph_runtime_mode == CUDAGraphMode.FULL
|
|
and uniform_decode
|
|
and check_ubatch_thresholds(
|
|
config=self.vllm_config.parallel_config,
|
|
num_tokens=num_tokens,
|
|
uniform_decode=uniform_decode,
|
|
)
|
|
)
|
|
|
|
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
|
|
                # Use CUDAGraphMode.NONE (default) for warmup.
                # But be careful: warming up with `NONE` is orthogonal to
                # whether we want to warm up attention or not. This is
                # different from the case where `FULL` implies capturing
                # attention while `PIECEWISE` implies no attention.
|
|
force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
|
|
self._dummy_run(
|
|
num_tokens,
|
|
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
|
force_attention=force_attention,
|
|
uniform_decode=uniform_decode,
|
|
allow_microbatching=allow_microbatching,
|
|
skip_eplb=True,
|
|
remove_lora=False,
|
|
)
|
|
self._dummy_run(
|
|
num_tokens,
|
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
|
uniform_decode=uniform_decode,
|
|
allow_microbatching=allow_microbatching,
|
|
skip_eplb=True,
|
|
remove_lora=False,
|
|
)
|
|
self.maybe_remove_all_loras(self.lora_config)
|
|
|
|
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
|
|
"""
|
|
Initialize the attention backends and attention metadata builders.
|
|
"""
|
|
assert len(self.attn_groups) == 0, "Attention backends are already initialized"
|
|
|
|
class AttentionGroupKey(NamedTuple):
|
|
attn_backend: type[AttentionBackend]
|
|
kv_cache_spec: KVCacheSpec
|
|
|
|
def get_attn_backends_for_group(
|
|
kv_cache_group_spec: KVCacheGroupSpec,
|
|
) -> dict[AttentionGroupKey, list[str]]:
|
|
layers = get_layers_from_vllm_config(
|
|
self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names
|
|
)
|
|
attn_backends = {}
|
|
attn_backend_layers = defaultdict(list)
|
|
            # Dedupe based on full class name; this is a bit safer than using
            # the class itself as the key because, when we create dynamic
            # attention backend subclasses (e.g. ChunkedLocalAttention),
            # there will be different class objects per layer unless they are
            # cached correctly.
|
|
for layer_name in kv_cache_group_spec.layer_names:
|
|
attn_backend = layers[layer_name].get_attn_backend()
|
|
|
|
if layer_name in self.kv_sharing_fast_prefill_eligible_layers:
|
|
attn_backend = create_fast_prefill_custom_backend(
|
|
"FastPrefill",
|
|
attn_backend,
|
|
)
|
|
|
|
full_cls_name = attn_backend.full_cls_name()
|
|
layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec
|
|
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
|
|
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name]
|
|
key = (full_cls_name, layer_kv_cache_spec)
|
|
attn_backends[key] = AttentionGroupKey(
|
|
attn_backend, layer_kv_cache_spec
|
|
)
|
|
attn_backend_layers[key].append(layer_name)
|
|
return {attn_backends[k]: v for k, v in attn_backend_layers.items()}
|
|
|
|
def create_attn_groups(
|
|
attn_backends_map: dict[AttentionGroupKey, list[str]],
|
|
) -> list[AttentionGroup]:
|
|
attn_groups: list[AttentionGroup] = []
|
|
for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
|
|
attn_group = AttentionGroup.create_with_metadata_builders(
|
|
attn_backend,
|
|
layer_names,
|
|
kv_cache_spec,
|
|
self.vllm_config,
|
|
self.device,
|
|
num_metadata_builders=1
|
|
if not self.parallel_config.enable_dbo
|
|
else 2,
|
|
)
|
|
|
|
attn_groups.append(attn_group)
|
|
return attn_groups
|
|
|
|
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
|
|
attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
|
|
self.attn_groups.append(create_attn_groups(attn_backends))
|
|
|
|
# Calculate reorder batch threshold (if needed)
|
|
self.calculate_reorder_batch_threshold()
|
|
|
|
def initialize_cudagraph_capture(self) -> None:
|
|
"""
|
|
Resolve the cudagraph_mode when there are multiple attention
|
|
        backends with potentially conflicting CUDA graph support.
|
|
Then initialize the cudagraph_dispatcher based on the resolved
|
|
cudagraph_mode.
|
|
"""
|
|
min_cg_support = AttentionCGSupport.ALWAYS
|
|
min_cg_builder_name = None
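        # Find the most restrictive CUDA graph support level across all
        # attention metadata builders; it determines what the requested
        # cudagraph_mode may be downgraded to below.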
|
|
|
|
for attn_group in self._attn_group_iterator():
|
|
builder = attn_group.get_metadata_builder()
|
|
if builder.cudagraph_support.value < min_cg_support.value:
|
|
min_cg_support = builder.cudagraph_support
|
|
min_cg_builder_name = builder.__class__.__name__
|
|
        # Flexibly resolve the cudagraph mode
|
|
cudagraph_mode = self.compilation_config.cudagraph_mode
|
|
        # Check that full cudagraphs for mixed batches are supported
|
|
if (
|
|
cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL
|
|
and min_cg_support != AttentionCGSupport.ALWAYS
|
|
):
|
|
msg = (
|
|
f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
|
|
f"with {min_cg_builder_name} backend (support: "
|
|
f"{min_cg_support})"
|
|
)
|
|
if min_cg_support == AttentionCGSupport.NEVER:
|
|
                # If full cudagraphs are not supported at all, just raise it.
|
|
msg += (
|
|
"; please try cudagraph_mode=PIECEWISE, and "
|
|
"make sure compilation level is piecewise"
|
|
)
|
|
raise ValueError(msg)
|
|
|
|
# attempt to resolve the full cudagraph related mode
|
|
if self.compilation_config.splitting_ops_contain_attention():
|
|
msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
|
|
cudagraph_mode = self.compilation_config.cudagraph_mode = (
|
|
CUDAGraphMode.FULL_AND_PIECEWISE
|
|
)
|
|
else:
|
|
msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
|
|
cudagraph_mode = self.compilation_config.cudagraph_mode = (
|
|
CUDAGraphMode.FULL_DECODE_ONLY
|
|
)
|
|
logger.warning(msg)
|
|
|
|
        # Check that full cudagraphs for decode are supported, if requested
|
|
if (
|
|
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
|
|
and min_cg_support == AttentionCGSupport.NEVER
|
|
):
|
|
msg = (
|
|
f"CUDAGraphMode.{cudagraph_mode.name} is not supported "
|
|
f"with {min_cg_builder_name} backend (support: "
|
|
f"{min_cg_support})"
|
|
)
|
|
if self.compilation_config.level == CompilationLevel.PIECEWISE and (
|
|
self.compilation_config.splitting_ops_contain_attention()
|
|
or self.compilation_config.use_inductor_graph_partition
|
|
):
|
|
msg += (
|
|
"; setting cudagraph_mode=PIECEWISE because "
|
|
"attention is compiled piecewise"
|
|
)
|
|
cudagraph_mode = self.compilation_config.cudagraph_mode = (
|
|
CUDAGraphMode.PIECEWISE
|
|
)
|
|
else:
|
|
msg += (
|
|
"; setting cudagraph_mode=NONE because "
|
|
"attention is not compiled piecewise"
|
|
)
|
|
cudagraph_mode = self.compilation_config.cudagraph_mode = (
|
|
CUDAGraphMode.NONE
|
|
)
|
|
logger.warning(msg)
|
|
|
|
        # Check that full cudagraphs for decode with spec-decode are
        # supported, if requested
|
|
if (
|
|
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
|
|
and self.uniform_decode_query_len > 1
|
|
and min_cg_support.value < AttentionCGSupport.UNIFORM_BATCH.value
|
|
):
|
|
msg = (
|
|
f"CUDAGraphMode.{cudagraph_mode.name} is not supported"
|
|
f" with spec-decode for attention backend "
|
|
f"{min_cg_builder_name} (support: {min_cg_support})"
|
|
)
|
|
if self.compilation_config.splitting_ops_contain_attention():
|
|
msg += "; setting cudagraph_mode=PIECEWISE"
|
|
cudagraph_mode = self.compilation_config.cudagraph_mode = (
|
|
CUDAGraphMode.PIECEWISE
|
|
)
|
|
else:
|
|
msg += "; setting cudagraph_mode=NONE"
|
|
cudagraph_mode = self.compilation_config.cudagraph_mode = (
|
|
CUDAGraphMode.NONE
|
|
)
|
|
logger.warning(msg)
|
|
|
|
        # Double check that full cudagraphs are supported if they are still
        # requested even after the automatic downgrades above
|
|
if (
|
|
cudagraph_mode.has_full_cudagraphs()
|
|
and min_cg_support == AttentionCGSupport.NEVER
|
|
):
|
|
raise ValueError(
|
|
f"CUDAGraphMode.{cudagraph_mode.name} is not "
|
|
f"supported with {min_cg_builder_name} backend ("
|
|
f"support:{min_cg_support}) "
|
|
"; please try cudagraph_mode=PIECEWISE, "
|
|
"and make sure compilation level is piecewise"
|
|
)
|
|
|
|
# Trigger cudagraph dispatching keys initialization here (after
|
|
# initializing attn backends).
|
|
self.cudagraph_dispatcher.initialize_cudagraph_keys(
|
|
self.compilation_config.cudagraph_mode, self.uniform_decode_query_len
|
|
)
|
|
|
|
def calculate_reorder_batch_threshold(self) -> None:
|
|
"""
|
|
        Check that, if any backends reorder batches, the reordering is
        compatible (e.g., the decode threshold is the same across backends).
|
|
"""
|
|
for group in self._attn_group_iterator():
|
|
attn_metadata_builder_i = group.get_metadata_builder()
|
|
|
|
            # Check that, if any backends reorder batches, the reordering is
            # compatible (e.g., the decode threshold is the same).
|
|
reorder_batch_threshold_i = attn_metadata_builder_i.reorder_batch_threshold
|
|
if reorder_batch_threshold_i is not None:
|
|
if self.reorder_batch_threshold is not None:
|
|
if reorder_batch_threshold_i != self.reorder_batch_threshold:
|
|
raise ValueError(
|
|
f"Attention backend reorders decodes with "
|
|
f"threshold {reorder_batch_threshold_i} but other "
|
|
f"backend uses threshold "
|
|
f"{self.reorder_batch_threshold}"
|
|
)
|
|
else:
|
|
self.reorder_batch_threshold = reorder_batch_threshold_i
|
|
|
|
def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
|
|
"""
|
|
Re-initialize the input batch if the block sizes are different from
|
|
`[self.cache_config.block_size]`. This usually happens when there
|
|
are multiple KV cache groups.
|
|
|
|
Args:
|
|
kv_cache_config: The KV cache configuration.
|
|
"""
|
|
block_sizes = [
|
|
kv_cache_group.kv_cache_spec.block_size
|
|
for kv_cache_group in kv_cache_config.kv_cache_groups
|
|
]
|
|
if block_sizes != [self.cache_config.block_size]:
|
|
assert self.cache_config.cpu_offload_gb == 0, (
|
|
"Cannot re-initialize the input batch when CPU weight "
|
|
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
|
|
"for more details."
|
|
)
|
|
self.input_batch = InputBatch(
|
|
max_num_reqs=self.max_num_reqs,
|
|
max_model_len=max(self.max_model_len, self.max_encoder_len),
|
|
max_num_batched_tokens=self.max_num_tokens,
|
|
device=self.device,
|
|
pin_memory=self.pin_memory,
|
|
vocab_size=self.model_config.get_vocab_size(),
|
|
block_sizes=block_sizes,
|
|
is_spec_decode=bool(self.vllm_config.speculative_config),
|
|
logitsprocs=self.input_batch.logitsprocs,
|
|
is_pooling_model=self.is_pooling_model,
|
|
num_speculative_tokens=(
|
|
self.vllm_config.speculative_config.num_speculative_tokens
|
|
if self.vllm_config.speculative_config
|
|
else 0
|
|
),
|
|
)
|
|
|
|
def _allocate_kv_cache_tensors(
|
|
self, kv_cache_config: KVCacheConfig
|
|
) -> dict[str, torch.Tensor]:
|
|
"""
|
|
Initializes the KV cache buffer with the correct size. The buffer needs
|
|
to be reshaped to the desired shape before being used by the models.
|
|
|
|
Args:
|
|
kv_cache_config: The KV cache config
|
|
Returns:
|
|
            dict[str, torch.Tensor]: A map from layer names to their
                corresponding memory buffers for KV cache.
|
|
"""
|
|
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
|
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
|
tensor = torch.zeros(
|
|
kv_cache_tensor.size, dtype=torch.int8, device=self.device
|
|
)
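            # A single raw buffer may back multiple layers (see `shared_by`,
            # e.g. layers from different KV cache groups that reuse the same
            # memory); register it under every layer that uses it.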
|
|
for layer_name in kv_cache_tensor.shared_by:
|
|
kv_cache_raw_tensors[layer_name] = tensor
|
|
|
|
layer_names = set()
|
|
for group in kv_cache_config.kv_cache_groups:
|
|
for layer_name in group.layer_names:
|
|
if layer_name in self.runner_only_attn_layers:
|
|
continue
|
|
layer_names.add(layer_name)
|
|
assert layer_names == set(kv_cache_raw_tensors.keys()), (
|
|
"Some layers are not correctly initialized"
|
|
)
|
|
return kv_cache_raw_tensors
|
|
|
|
def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
|
|
return itertools.chain.from_iterable(self.attn_groups)
|
|
|
|
def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]:
|
|
if not self.kv_cache_config.kv_cache_groups:
|
|
return
|
|
for attn_groups in self.attn_groups:
|
|
yield from attn_groups
|
|
|
|
def _reshape_kv_cache_tensors(
|
|
self,
|
|
kv_cache_config: KVCacheConfig,
|
|
kv_cache_raw_tensors: dict[str, torch.Tensor],
|
|
) -> dict[str, torch.Tensor]:
|
|
"""
|
|
Reshape the KV cache tensors to the desired shape and dtype.
|
|
|
|
Args:
|
|
kv_cache_config: The KV cache config
|
|
kv_cache_raw_tensors: The KV cache buffer of each layer, with
|
|
correct size but uninitialized shape.
|
|
Returns:
|
|
            Dict[str, torch.Tensor]: A map from layer names to their
                corresponding memory buffers for KV cache.
|
|
"""
|
|
kv_caches: dict[str, torch.Tensor] = {}
|
|
has_attn, has_mamba = False, False
|
|
for group in self._kv_cache_spec_attn_group_iterator():
|
|
kv_cache_spec = group.kv_cache_spec
|
|
attn_backend = group.backend
|
|
for layer_name in group.layer_names:
|
|
if layer_name in self.runner_only_attn_layers:
|
|
continue
|
|
raw_tensor = kv_cache_raw_tensors[layer_name]
|
|
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
|
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
|
if isinstance(kv_cache_spec, AttentionSpec):
|
|
has_attn = True
|
|
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
|
num_blocks,
|
|
kv_cache_spec.block_size,
|
|
kv_cache_spec.num_kv_heads,
|
|
kv_cache_spec.head_size,
|
|
cache_dtype_str=self.cache_config.cache_dtype,
|
|
)
|
|
dtype = kv_cache_spec.dtype
|
|
try:
|
|
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
|
|
assert len(kv_cache_stride_order) == len(kv_cache_shape)
|
|
except (AttributeError, NotImplementedError):
|
|
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
|
|
# The allocation respects the backend-defined stride order
|
|
# to ensure the semantic remains consistent for each
|
|
# backend. We first obtain the generic kv cache shape and
|
|
# then permute it according to the stride order which could
|
|
# result in a non-contiguous tensor.
|
|
kv_cache_shape = tuple(
|
|
kv_cache_shape[i] for i in kv_cache_stride_order
|
|
)
|
|
# Maintain original KV shape view.
|
|
inv_order = [
|
|
kv_cache_stride_order.index(i)
|
|
for i in range(len(kv_cache_stride_order))
|
|
]
|
|
kv_caches[layer_name] = (
|
|
kv_cache_raw_tensors[layer_name]
|
|
.view(dtype)
|
|
.view(kv_cache_shape)
|
|
.permute(*inv_order)
|
|
)
|
|
elif isinstance(kv_cache_spec, MambaSpec):
|
|
has_mamba = True
|
|
raw_tensor = kv_cache_raw_tensors[layer_name]
|
|
state_tensors = []
|
|
storage_offset_bytes = 0
|
|
for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes):
|
|
dtype_size = get_dtype_size(dtype)
|
|
num_element_per_page = (
|
|
kv_cache_spec.page_size_bytes // dtype_size
|
|
)
|
|
target_shape = (num_blocks, *shape)
|
|
stride = torch.empty(target_shape).stride()
|
|
target_stride = (num_element_per_page, *stride[1:])
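                        # All state tensors of this Mamba layer live side by
                        # side within one page: dim 0 strides over whole pages
                        # (num_element_per_page) while storage_offset_bytes
                        # advances to the start of each state within a page.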
|
|
assert storage_offset_bytes % dtype_size == 0
|
|
tensor = torch.as_strided(
|
|
raw_tensor.view(dtype),
|
|
size=target_shape,
|
|
stride=target_stride,
|
|
storage_offset=storage_offset_bytes // dtype_size,
|
|
)
|
|
state_tensors.append(tensor)
|
|
storage_offset_bytes += stride[0] * dtype_size
|
|
|
|
kv_caches[layer_name] = state_tensors
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
if has_attn and has_mamba:
|
|
self._update_hybrid_attention_mamba_layout(kv_caches)
|
|
|
|
return kv_caches
|
|
|
|
def _update_hybrid_attention_mamba_layout(
|
|
self, kv_caches: dict[str, torch.Tensor]
|
|
) -> None:
|
|
"""
|
|
Update the layout of attention layers from (2, num_blocks, ...) to
|
|
(num_blocks, 2, ...).
|
|
|
|
Args:
|
|
kv_caches: The KV cache buffer of each layer.
|
|
"""
|
|
|
|
for group in self._kv_cache_spec_attn_group_iterator():
|
|
kv_cache_spec = group.kv_cache_spec
|
|
for layer_name in group.layer_names:
|
|
kv_cache = kv_caches[layer_name]
|
|
if isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2:
|
|
assert kv_cache.shape[1] != 2, (
|
|
"Fail to determine whether the layout is "
|
|
"(2, num_blocks, ...) or (num_blocks, 2, ...) for "
|
|
f"a tensor of shape {kv_cache.shape}"
|
|
)
|
|
hidden_size = kv_cache.shape[2:].numel()
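                    # Keep the logical (2, num_blocks, ...) shape but swap the
                    # strides of the first two dims so the underlying memory
                    # is laid out as (num_blocks, 2, ...), i.e. K and V of the
                    # same block become adjacent, without copying any data.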
|
|
kv_cache.as_strided_(
|
|
size=kv_cache.shape,
|
|
stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]),
|
|
)
|
|
|
|
def initialize_kv_cache_tensors(
|
|
self, kv_cache_config: KVCacheConfig
|
|
) -> dict[str, torch.Tensor]:
|
|
"""
|
|
Initialize the memory buffer for KV cache.
|
|
|
|
Args:
|
|
kv_cache_config: The KV cache config
|
|
Returns:
|
|
            Dict[str, torch.Tensor]: A map from layer names to their
                corresponding memory buffers for KV cache.
|
|
"""
|
|
# Initialize the memory buffer for KV cache
|
|
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
|
|
# Change the memory buffer to the desired shape
|
|
kv_caches = self._reshape_kv_cache_tensors(
|
|
kv_cache_config, kv_cache_raw_tensors
|
|
)
|
|
|
|
# Set up cross-layer KV cache sharing
|
|
for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
|
|
logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name)
|
|
kv_caches[layer_name] = kv_caches[target_layer_name]
|
|
|
|
num_attn_module = (
|
|
2 if self.model_config.hf_config.model_type == "longcat_flash" else 1
|
|
)
|
|
bind_kv_cache(
|
|
kv_caches,
|
|
self.compilation_config.static_forward_context,
|
|
self.kv_caches,
|
|
num_attn_module,
|
|
)
|
|
return kv_caches
|
|
|
|
def maybe_add_kv_sharing_layers_to_kv_cache_groups(
|
|
self, kv_cache_config: KVCacheConfig
|
|
) -> None:
|
|
"""
|
|
        Add layers that re-use KV cache to the KV cache group of their
        target layer.
|
|
Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`
|
|
"""
|
|
if not self.shared_kv_cache_layers:
|
|
# No cross-layer KV sharing, return
|
|
return
|
|
|
|
add_kv_sharing_layers_to_kv_cache_groups(
|
|
self.shared_kv_cache_layers,
|
|
kv_cache_config.kv_cache_groups,
|
|
self.runner_only_attn_layers,
|
|
)
|
|
|
|
if self.cache_config.kv_sharing_fast_prefill:
|
|
# In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
|
|
# similar KV sharing setups, only the layers that generate KV caches
|
|
            # are involved in the prefill phase, enabling prefill to exit early.
|
|
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
|
for layer_name in reversed(attn_layers):
|
|
if layer_name in self.shared_kv_cache_layers:
|
|
self.kv_sharing_fast_prefill_eligible_layers.add(layer_name)
|
|
else:
|
|
break
|
|
|
|
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
|
"""
|
|
Initialize KV cache based on `kv_cache_config`.
|
|
Args:
|
|
kv_cache_config: Configuration for the KV cache, including the KV
|
|
cache size of each layer
|
|
"""
|
|
kv_cache_config = deepcopy(kv_cache_config)
|
|
self.kv_cache_config = kv_cache_config
|
|
self.may_reinitialize_input_batch(kv_cache_config)
|
|
self.may_add_encoder_only_layers_to_kv_cache_config()
|
|
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
|
|
self.initialize_attn_backend(kv_cache_config)
|
|
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
|
|
|
|
if self.speculative_config and self.speculative_config.use_eagle():
|
|
assert isinstance(self.drafter, EagleProposer)
|
|
# validate all draft model layers belong to the same kv cache
|
|
# group
|
|
self.drafter.validate_same_kv_cache_group(kv_cache_config)
|
|
|
|
if has_kv_transfer_group():
|
|
kv_transfer_group = get_kv_transfer_group()
|
|
kv_transfer_group.register_kv_caches(kv_caches)
|
|
kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)
|
|
|
|
if self.dcp_world_size > 1:
|
|
layer_names = self.attn_groups[0][0].layer_names
|
|
layers = get_layers_from_vllm_config(
|
|
self.vllm_config, AttentionLayerBase, layer_names
|
|
)
|
|
for layer in layers.values():
|
|
assert layer.impl.need_to_return_lse_for_decode, (
|
|
"DCP requires attention impls to return"
|
|
" the softmax lse for decode, but the impl "
|
|
f"{layer.impl.__class__.__name__} "
|
|
"does not return the softmax lse for decode."
|
|
)
|
|
|
|
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
|
|
"""
|
|
Add encoder-only layers to the KV cache config.
|
|
"""
|
|
block_size = self.vllm_config.cache_config.block_size
|
|
encoder_only_attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
|
|
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
|
for layer_name, attn_module in attn_layers.items():
|
|
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
|
|
attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
|
|
block_size=block_size,
|
|
num_kv_heads=attn_module.num_kv_heads,
|
|
head_size=attn_module.head_size,
|
|
dtype=self.kv_cache_dtype,
|
|
)
|
|
encoder_only_attn_specs[attn_spec].append(layer_name)
|
|
self.runner_only_attn_layers.add(layer_name)
|
|
if len(encoder_only_attn_specs) > 0:
|
|
assert len(encoder_only_attn_specs) == 1, (
|
|
"Only support one encoder-only attention spec now"
|
|
)
|
|
spec, layer_names = encoder_only_attn_specs.popitem()
|
|
self.kv_cache_config.kv_cache_groups.append(
|
|
KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)
|
|
)
|
|
|
|
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
|
"""
|
|
Generates the KVCacheSpec by parsing the kv cache format from each
|
|
Attention module in the static forward context.
|
|
Returns:
|
|
KVCacheSpec: A dictionary mapping layer names to their KV cache
|
|
format. Layers that do not need KV cache are not included.
|
|
"""
|
|
|
|
block_size = self.vllm_config.cache_config.block_size
|
|
use_mla = self.vllm_config.model_config.use_mla
|
|
cache_dtype_str = self.vllm_config.cache_config.cache_dtype
|
|
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
|
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
|
for layer_name, attn_module in attn_layers.items():
|
|
if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None:
|
|
                # The layer doesn't need its own KV cache and will use that of
                # the target layer. We skip creating a KVCacheSpec for it, so
                # the KV cache management logic acts as if this layer does not
                # exist and doesn't allocate any KV cache for it. This enables
                # the memory saving of cross-layer KV sharing, allowing a given
                # amount of memory to accommodate longer context lengths or
                # more requests to be processed simultaneously.
|
|
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
|
|
continue
|
|
|
|
# TODO(lucas): move the attention specs into the model layers like
|
|
# the attention backends
|
|
if attn_module.attn_type == AttentionType.DECODER:
|
|
if attn_module.sliding_window is not None:
|
|
                    assert not use_mla, "MLA is not supported for sliding window"
|
|
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
|
block_size=block_size,
|
|
num_kv_heads=attn_module.num_kv_heads,
|
|
head_size=attn_module.head_size,
|
|
dtype=self.kv_cache_dtype,
|
|
sliding_window=attn_module.sliding_window,
|
|
)
|
|
elif use_mla:
|
|
kv_cache_spec[layer_name] = MLAAttentionSpec(
|
|
block_size=block_size,
|
|
num_kv_heads=attn_module.num_kv_heads,
|
|
head_size=attn_module.head_size,
|
|
dtype=self.kv_cache_dtype,
|
|
cache_dtype_str=cache_dtype_str,
|
|
)
|
|
elif self.attention_chunk_size is not None and isinstance(
|
|
attn_module, ChunkedLocalAttention
|
|
):
|
|
kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
|
|
block_size=block_size,
|
|
num_kv_heads=attn_module.num_kv_heads,
|
|
head_size=attn_module.head_size,
|
|
dtype=self.kv_cache_dtype,
|
|
attention_chunk_size=self.attention_chunk_size,
|
|
)
|
|
else:
|
|
kv_cache_spec[layer_name] = FullAttentionSpec(
|
|
block_size=block_size,
|
|
num_kv_heads=attn_module.num_kv_heads,
|
|
head_size=attn_module.head_size,
|
|
dtype=self.kv_cache_dtype,
|
|
)
|
|
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
|
kv_cache_spec[layer_name] = CrossAttentionSpec(
|
|
block_size=block_size,
|
|
num_kv_heads=attn_module.num_kv_heads,
|
|
head_size=attn_module.head_size,
|
|
dtype=self.kv_cache_dtype,
|
|
)
|
|
elif attn_module.attn_type in (
|
|
AttentionType.ENCODER,
|
|
AttentionType.ENCODER_ONLY,
|
|
):
|
|
# encoder-only attention does not need KV cache.
|
|
continue
|
|
else:
|
|
raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
|
|
|
|
mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
|
|
if len(mamba_layers) > 0:
|
|
if (
|
|
self.vllm_config.speculative_config is not None
|
|
and self.vllm_config.model_config.hf_config.model_type
|
|
not in ["qwen3_next"]
|
|
):
|
|
raise NotImplementedError(
|
|
"Mamba with speculative decoding is not supported yet."
|
|
)
|
|
mamba_block_size = self.vllm_config.cache_config.mamba_block_size
|
|
page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded
|
|
|
|
for layer_name, mamba_module in mamba_layers.items():
|
|
kv_cache_spec[layer_name] = MambaSpec(
|
|
shapes=mamba_module.get_state_shape(),
|
|
dtypes=mamba_module.get_state_dtype(),
|
|
block_size=mamba_block_size,
|
|
page_size_padded=page_size_padded,
|
|
mamba_type=mamba_module.mamba_type,
|
|
num_speculative_blocks=(
|
|
self.speculative_config.num_speculative_tokens
|
|
if self.speculative_config
|
|
else 0
|
|
),
|
|
)
|
|
ds_indexer_layers = get_layers_from_vllm_config(
|
|
self.vllm_config, DeepseekV32IndexerCache
|
|
)
|
|
for layer_name, ds_indexer_module in ds_indexer_layers.items():
|
|
kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec()
|
|
|
|
return kv_cache_spec
|
|
|
|
def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
|
|
        # This is a short-term mitigation for the issue mentioned in
        # https://github.com/vllm-project/vllm/issues/22754.
        # `tolist` would trigger a stream-wide CUDA sync, which would
        # block copy ops on other CUDA streams. A CUDA event sync avoids
        # such a situation. Since this is in the critical path of every
        # model forward loop, it has caused perf issues for
        # disaggregated setups.
|
|
pinned = self.sampled_token_ids_pinned_cpu[: sampled_token_ids.shape[0]]
|
|
pinned.copy_(sampled_token_ids, non_blocking=True)
|
|
self.transfer_event.record()
|
|
self.transfer_event.synchronize()
|
|
return pinned.tolist()
|