### What this PR does / why we need it?

The `update_attn_params` function requires a `num_tokens` argument, which was previously derived from `positions.shape[0]`. However, multimodal models use M-RoPE (multimodal rotary position embedding), so the `positions` tensor is two-dimensional and `positions.shape[0]` no longer equals the number of tokens, yielding an incorrect value. We resolve this by passing `maybe_padded_num_tokens` instead of `positions.shape[0]`.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: fan2956 <zhoufan53@huawei.com>
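For context, a minimal sketch of the shape mismatch described above, using illustrative tensors only (not the actual diff):

```python
import torch

num_tokens = 8

# Text-only models: positions is 1-D, so shape[0] equals the token count.
positions_1d = torch.arange(num_tokens)
assert positions_1d.shape[0] == num_tokens

# M-RoPE models: positions is 2-D with shape (3, num_tokens), so shape[0]
# is 3 rather than the token count, which breaks update_attn_params.
positions_mrope = torch.zeros(3, num_tokens, dtype=torch.int64)
assert positions_mrope.shape[0] == 3

# The fix passes the already-known (possibly padded) token count, i.e.
# maybe_padded_num_tokens, instead of deriving it from positions.
```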
3580 lines · 170 KiB · Python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
#

import copy
import gc
import itertools
import re
import time
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from dataclasses import dataclass
from multiprocessing import Manager
from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
                    Union, cast)

import numpy as np
import numpy.typing as npt
import torch
import torch._dynamo.cache_size
import torch.distributed as dist
import torch.nn as nn
from tqdm import tqdm  # type: ignore
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                         get_layers_from_vllm_config)
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
                                             get_tp_group,
                                             is_global_first_rank)
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import supports_transcription
from vllm.model_executor.models.interfaces_base import (
    VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        LazyLoader, cdiv, get_dtype_size,
                        is_pin_memory_available)
from vllm.utils.jsontree import json_map_leaves
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# yapf conflicts with isort for this block
# yapf: disable
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                        KVCacheConfig, KVCacheGroupSpec,
                                        KVCacheSpec, MambaSpec,
                                        MLAAttentionSpec,
                                        UniformTypeKVCacheSpecs)
# yapf: enable
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                             DraftTokenIds, LogprobsTensors, ModelRunnerOutput,
                             PoolerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
                                  gather_mm_placeholders,
                                  sanity_check_mm_encoder_outputs,
                                  scatter_mm_placeholders)

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import (MoECommType,
                                                set_ascend_forward_context)
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                               set_graph_params,
                                               update_attn_params,
                                               update_mla_attn_params)
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
    D2DExpertWeightLoader
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
from vllm_ascend.eplb.eplb_updator import EplbUpdator
from vllm_ascend.eplb.utils import model_register
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.logits_processor import build_logitsprocs
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                               AscendSocVersion, ProfileExecuteDuration,
                               get_ascend_soc_version, is_310p, is_enable_nz,
                               lmhead_tp_enable)
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

if TYPE_CHECKING:
    import xgrammar as xgr  # type: ignore[import-untyped]
    from vllm.v1.core.sched.output import SchedulerOutput
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")

import torch_npu

# if true, allow tensor initialization and casting with internal format (e.g., NZ)
torch.npu.config.allow_internal_format = True

if is_310p():
    torch_npu.npu.set_compile_mode(jit_compile=False)
    ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ
else:
    ACL_FORMAT = ACL_FORMAT_FRACTAL_ND


@dataclass
class GraphCaptureContext:
    stream: torch.npu.Stream


@contextmanager
def graph_capture(device: torch.device):
    """
    `graph_capture` is a context manager which should surround the code that
    is capturing the NPU graph. Its main purpose is to ensure that some
    operations will be run after the graph is captured, before the graph
    is replayed. It returns a `GraphCaptureContext` object which contains the
    necessary data for the graph capture. Currently, it only contains the
    stream that the graph capture is running on. This stream is set to the
    current NPU stream when the context manager is entered and reset to the
    default stream when the context manager is exited. This is to ensure that
    the graph capture is running on a separate stream from the default stream,
    in order to explicitly distinguish the kernels to capture
    from other kernels possibly launched in the background on the default stream.
    """
    graph_capture_context = GraphCaptureContext(
        torch.npu.Stream(device=device))
    stream = graph_capture_context.stream

    # we use nullcontext now
    maybe_ca_context = nullcontext()

    # ensure all initialization operations complete before attempting to
    # capture the graph on another stream
    curr_stream = torch.npu.current_stream()
    if curr_stream != stream:
        stream.wait_stream(curr_stream)

    with torch.npu.stream(stream), maybe_ca_context:
        yield graph_capture_context


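# NOTE (illustrative usage sketch): how `graph_capture` is meant to wrap a
# capture region, assuming an NPU graph-capture API analogous to
# `torch.cuda.CUDAGraph`; the exact torch_npu entry points may differ.
#
#     with graph_capture(device) as capture_ctx:
#         # warm-up / capture runs happen on capture_ctx.stream, keeping the
#         # captured kernels separate from work on the default stream.
#         static_out = model(static_input)
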

# Wrapper for ModelRunnerOutput to support overlapped execution.
class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):

    def __init__(
        self,
        model_runner_output: ModelRunnerOutput,
        sampled_token_ids: torch.Tensor,
        invalid_req_indices: list[int],
        async_output_copy_stream: torch.npu.Stream,
    ):
        self._model_runner_output = model_runner_output
        self._invalid_req_indices = invalid_req_indices

        # Event on the copy stream so we can synchronize the non-blocking copy.
        self._async_copy_ready_event = torch.npu.Event()

        # Keep a reference to the device tensor to avoid it being
        # deallocated until we finish copying it to the host.
        self._sampled_token_ids = sampled_token_ids

        # Initiate the copy on a separate stream, but do not synchronize it.
        default_stream = torch.npu.current_stream()
        with torch.npu.stream(async_output_copy_stream):
            async_output_copy_stream.wait_stream(default_stream)
            self._sampled_token_ids_cpu = self._sampled_token_ids.to(
                'cpu', non_blocking=True)
            self._async_copy_ready_event.record()

    def get_output(self) -> ModelRunnerOutput:
        """Copy the device tensors to the host and return a ModelRunnerOutput.

        This function blocks until the copy is finished.
        """
        self._async_copy_ready_event.synchronize()

        # Release the device tensor once the copy has completed
        del self._sampled_token_ids

        valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist()
        for i in self._invalid_req_indices:
            valid_sampled_token_ids[i].clear()

        output = self._model_runner_output
        output.sampled_token_ids = valid_sampled_token_ids
        return output


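# NOTE (illustrative): the pattern used above in isolation -- copy
# device -> host on a side stream and gate consumption on an event, so the
# default stream is never blocked. Sketch only, assuming torch_npu mirrors
# the CUDA stream/event semantics:
#
#     copy_stream = torch.npu.Stream()
#     ready = torch.npu.Event()
#     with torch.npu.stream(copy_stream):
#         copy_stream.wait_stream(torch.npu.current_stream())
#         host_copy = device_tensor.to("cpu", non_blocking=True)
#         ready.record()
#     ...                      # other work proceeds on the default stream
#     ready.synchronize()      # block only when the host copy is needed
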
class NPUModelRunner(LoRAModelRunnerMixin):
|
|
|
|
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
|
self.vllm_config = vllm_config
|
|
self.model_config = vllm_config.model_config
|
|
self.cache_config = vllm_config.cache_config
|
|
self.compilation_config = vllm_config.compilation_config
|
|
self.load_config = vllm_config.load_config
|
|
self.lora_config = vllm_config.lora_config
|
|
self.parallel_config = vllm_config.parallel_config
|
|
self.pin_memory = is_pin_memory_available()
|
|
self.scheduler_config = vllm_config.scheduler_config
|
|
self.speculative_config = vllm_config.speculative_config
|
|
self.block_size = vllm_config.cache_config.block_size
|
|
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
|
|
self.block_size)
|
|
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
|
decode_max_num_seqs = getattr(self.scheduler_config,
|
|
'decode_max_num_seqs', 0)
|
|
self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
|
|
decode_max_num_seqs)
|
|
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
|
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
|
self.device = device
|
|
if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
|
|
self.prefetch_stream = torch.npu.Stream(device=device)
|
|
else:
|
|
self.prefetch_stream = None
|
|
self.dtype = self.model_config.dtype
|
|
if envs_ascend.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
|
|
# TODO: drop the env config to use ascend sampler by default
|
|
from vllm_ascend.sample.sampler import AscendSampler
|
|
|
|
self.sampler = AscendSampler()
|
|
else:
|
|
from vllm.v1.sample.sampler import Sampler
|
|
|
|
self.sampler = Sampler()
|
|
self.reorder_batch_threshold: Optional[int] = None
|
|
|
|
# Lazy initialization, these will be set after __init__
|
|
self.kv_caches: List[torch.Tensor] = []
|
|
self.attn_groups: list[list[AttentionGroup]] = []
|
|
self.encoder_cache: Dict[str, torch.Tensor] = {}
|
|
self.attn_mask = None
|
|
self.attn_state = None
|
|
self.requests: Dict[str, CachedRequestState] = {}
|
|
self.intermediate_tensors: Optional[IntermediateTensors] = None
|
|
self.runner_only_attn_layers: set[str] = set()
|
|
|
|
self.ascend_config = get_ascend_config()
|
|
if self.ascend_config.ascend_scheduler_config.enabled:
|
|
self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
|
|
else:
|
|
self.chunked_prefill_enabled = True
|
|
self.weight_prefetch_method = WeightPrefetchMethod(
|
|
self.ascend_config.weight_prefetch_config)
|
|
|
|
if self.cache_config.cache_dtype == "auto":
|
|
self.kv_cache_dtype = self.dtype
|
|
else:
|
|
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
|
self.cache_config.cache_dtype]
|
|
        # use_hybrid_blocks: whether hybrid blocks are used.
|
|
self.use_hybrid_blocks: bool = False
|
|
self.need_accepted_tokens: bool = False
|
|
|
|
self.is_multimodal_model = self.model_config.is_multimodal_model
|
|
self.is_pooling_model = self.model_config.pooler_config is not None
|
|
if self.is_multimodal_model:
|
|
self.inputs_embeds = torch.zeros(
|
|
(self.max_num_tokens, self.model_config.get_hidden_size()),
|
|
dtype=self.dtype,
|
|
device=self.device)
|
|
# Set up Attention
|
|
self.attn_backend = get_attn_backend(
|
|
0,
|
|
self.dtype,
|
|
None,
|
|
self.block_size,
|
|
use_mla=self.model_config.use_mla,
|
|
use_sfa=self.ascend_config.use_sfa)
|
|
if torch.version.cann.startswith("8.3"):
|
|
self.attn_mask_builder = AttentionMaskBuilder(
|
|
self.scheduler_config.max_num_batched_tokens, self.dtype,
|
|
self.device)
|
|
else:
|
|
self.attn_mask_builder = AttentionMaskBuilder(
|
|
self.model_config.max_model_len, self.dtype)
|
|
|
|
# Set up speculative decoding.
|
|
self.spec_attn_mask = None
|
|
self.drafter: Optional[Union[NgramProposer, EagleProposer,
|
|
MtpProposer]] = None
|
|
self.actual_seq_lengths_q: list[int] = []
|
|
self.decode_token_per_req = 1
|
|
if self.speculative_config:
|
|
spec_token_num = self.speculative_config.num_speculative_tokens
|
|
assert spec_token_num > 0
|
|
self.decode_token_per_req = 1 + spec_token_num
|
|
self.spec_attn_mask = torch.triu(torch.ones(2048,
|
|
2048,
|
|
dtype=torch.bool),
|
|
diagonal=1).to(self.device)
|
|
if get_pp_group().is_last_rank:
|
|
self.drafter = get_spec_decode_method(
|
|
self.speculative_config.method, self.vllm_config,
|
|
self.device, self)
|
|
self.rejection_sampler = AscendRejectionSampler()
|
|
|
|
# Persistent batch.
|
|
self.input_ids = torch.zeros(self.max_num_tokens,
|
|
dtype=torch.int32,
|
|
device=self.device)
|
|
self.positions = torch.zeros(self.max_num_tokens,
|
|
dtype=torch.int64,
|
|
device=self.device)
|
|
self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
|
|
dtype=torch.int32,
|
|
device=self.device)
|
|
self.seq_lens = torch.zeros(self.max_num_reqs,
|
|
dtype=torch.int32,
|
|
device=self.device)
|
|
self.slot_mapping = torch.zeros(self.max_num_tokens,
|
|
dtype=torch.int32,
|
|
device=self.device)
|
|
|
|
if self.vllm_config.model_config.use_mla and \
|
|
self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
|
|
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
|
self.cos = torch.ones(self.max_num_reqs,
|
|
1,
|
|
1,
|
|
rope_dim,
|
|
dtype=self.dtype,
|
|
device=self.device)
|
|
self.sin = torch.zeros(self.max_num_reqs,
|
|
1,
|
|
1,
|
|
rope_dim,
|
|
dtype=self.dtype,
|
|
device=self.device)
|
|
else:
|
|
self.cos = None
|
|
self.sin = None
|
|
|
|
self.uses_mrope = self.model_config.uses_mrope
|
|
        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
|
|
if self.uses_mrope:
|
|
# NOTE: `mrope_positions` is implemented with one additional dummy
|
|
# position on purpose to make it non-contiguous so that it can work
|
|
# with torch compile.
|
|
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
|
|
|
|
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
|
|
# the modality of inputs. For text-only inputs, each dimension has
|
|
# identical position IDs, making M-RoPE functionally equivalent to
|
|
# 1D-RoPE.
|
|
# See page 5 of https://arxiv.org/abs/2409.12191
|
|
self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
|
|
dtype=torch.int64,
|
|
device=self.device)
|
|
self.mrope_positions_cpu = torch.zeros(
|
|
(3, self.max_num_tokens + 1),
|
|
dtype=torch.int64,
|
|
device="cpu",
|
|
pin_memory=True)
|
|
self.mrope_positions_np = self.mrope_positions_cpu.numpy()
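            # Illustrative note: for a text-only prompt of length 4, all three
            # rows of `mrope_positions` are identical, e.g.
            #   [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]],
            # while for vision tokens the rows carry the temporal / height /
            # width components of the 3-D position (Qwen2-VL-style M-RoPE).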
|
|
|
|
# OPTIMIZATION: Cache the tensors rather than creating them every step.
|
|
self.arange_np: npt.NDArray[np.int32] = np.arange(max(
|
|
self.max_num_reqs + 1, self.model_config.max_model_len,
|
|
self.max_num_tokens),
|
|
dtype=np.int32)
|
|
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
|
|
# a faster version of creating a new tensor every time. Thus, we should
|
|
# not make any assumptions about the values in these tensors.
|
|
self.input_ids_cpu = torch.zeros(self.max_num_tokens,
|
|
dtype=torch.int32,
|
|
device="cpu",
|
|
pin_memory=True)
|
|
self.positions_cpu = torch.zeros(self.max_num_tokens,
|
|
dtype=torch.int64,
|
|
device="cpu",
|
|
pin_memory=True)
|
|
self.positions_np = self.positions_cpu.numpy()
|
|
|
|
self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
|
|
dtype=torch.int32,
|
|
device="cpu",
|
|
pin_memory=True)
|
|
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
|
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
|
|
dtype=torch.int32,
|
|
device="cpu",
|
|
pin_memory=True)
|
|
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
|
|
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
|
|
dtype=torch.int32,
|
|
device="cpu",
|
|
pin_memory=True)
|
|
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
|
|
|
self.use_aclgraph = self._use_aclgraph()
|
|
self.aclgraph_batch_sizes = list(
|
|
reversed(self.compilation_config.cudagraph_capture_sizes))
|
|
|
|
self.uniform_decode_query_len = 1 if not self.speculative_config else \
|
|
1 + self.speculative_config.num_speculative_tokens
|
|
# aclgraph dispatcher for runtime aclgraph dispatching.
|
|
self.aclgraph_dispatcher = CudagraphDispatcher(self.vllm_config)
|
|
# Cached outputs.
|
|
self._draft_token_ids: Optional[Union[list[list[int]],
|
|
torch.Tensor]] = None
|
|
|
|
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
|
|
self.in_profile_run = False
|
|
|
|
# kv role
|
|
self.is_kv_producer = False
|
|
self.is_kv_consumer = False
|
|
if vllm_config.kv_transfer_config is not None:
|
|
self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
|
|
self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
|
|
|
|
# NOTE: To be clear, we need to make sure that during graph capture, the number of
|
|
# tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
|
|
# the max number of tokens in graph is min(max_num_seqs * 2, 512).
|
|
if self.compilation_config.cudagraph_capture_sizes:
|
|
max_num_tokens = self.compilation_config.cudagraph_capture_sizes[0]
|
|
else:
|
|
max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
|
|
tp_size = self.parallel_config.tensor_parallel_size
|
|
# Use integer arithmetic for ceiling division.
|
|
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
|
|
self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
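        # Worked example (illustrative numbers): max_num_tokens=300, tp_size=8
        # -> num_tokens_per_tp_rank = ceil(300 / 8) = 38, so
        # mc2_tokens_capacity = 38 * 8 = 304, i.e. the capacity is rounded up
        # to a multiple of the TP size.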
|
|
self.reserved_mc2_mask = torch.zeros(
|
|
self.mc2_tokens_capacity,
|
|
dtype=torch.bool,
|
|
device=self.device,
|
|
)
|
|
|
|
self.dynamic_eplb = self.ascend_config.dynamic_eplb
|
|
if self.dynamic_eplb:
|
|
self.is_eplb_warmuped = False
|
|
self.policy_type = self.ascend_config.eplb_policy_type
|
|
self.eplb_loader = D2DExpertWeightLoader()
|
|
self.manager = Manager()
|
|
self.shared_dict = self.manager.dict({
|
|
"expert_map": None,
|
|
"moe_load": None,
|
|
"expert_maps": None
|
|
})
|
|
self.eplb_process = EplbProcess(shared_dict=self.shared_dict,
|
|
policy_type=self.policy_type,
|
|
enable_d2d=True)
|
|
self.process = self.eplb_process._launch_process()
|
|
ascend_config = get_ascend_config()
|
|
self.eplb_updator = EplbUpdator(ascend_config, self.eplb_loader,
|
|
self.eplb_process, self.process)
|
|
|
|
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
|
self.async_output_copy_stream = torch.npu.Stream() if \
|
|
self.use_async_scheduling else None
|
|
# Input Batch
|
|
# NOTE(Chen): Ideally, we should initialize the input batch inside
|
|
# `initialize_kv_cache` based on the kv cache config. However, as in
|
|
# https://github.com/vllm-project/vllm/pull/18298, due to some unknown
|
|
# reasons, we have to initialize the input batch before `load_model`,
|
|
# quantization + weight offloading will fail otherwise. As a temporary
|
|
# solution, we initialize the input batch here, and re-initialize it
|
|
# in `initialize_kv_cache` if the block_sizes here is different from
|
|
# the block_sizes in the kv cache config.
|
|
self.input_batch = InputBatch(
|
|
max_num_reqs=self.max_num_reqs,
|
|
max_model_len=self.model_config.max_model_len,
|
|
max_num_batched_tokens=self.max_num_tokens,
|
|
device=self.device,
|
|
pin_memory=self.pin_memory,
|
|
vocab_size=self.model_config.get_vocab_size(),
|
|
block_sizes=[self.block_size],
|
|
is_spec_decode=bool(self.vllm_config.speculative_config),
|
|
logitsprocs=build_logitsprocs(
|
|
self.vllm_config, self.device, self.pin_memory,
|
|
self.is_pooling_model,
|
|
self.vllm_config.model_config.logits_processors),
|
|
is_pooling_model=self.is_pooling_model,
|
|
kernel_block_sizes=[[self.vllm_config.cache_config.block_size]],
|
|
)
|
|
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
|
|
dtype=torch.int64)
|
|
self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
|
|
dtype=torch.int32)
|
|
|
|
def _make_buffer(self,
|
|
*size: Union[int, torch.SymInt],
|
|
dtype: torch.dtype,
|
|
numpy: bool = True) -> CpuGpuBuffer:
|
|
# Bfloat16 torch tensors cannot be directly cast to a numpy array, so
|
|
# if a bfloat16 buffer is needed without a corresponding numpy array,
|
|
# don't bother instantiating the numpy array.
|
|
return CpuGpuBuffer(*size,
|
|
dtype=dtype,
|
|
device=self.device,
|
|
pin_memory=self.pin_memory,
|
|
with_numpy=numpy)
|
|
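    # Illustrative note for the bfloat16 caveat in `_make_buffer` above: numpy
    # has no bfloat16 dtype, so torch.zeros(2, dtype=torch.bfloat16).numpy()
    # raises a TypeError, which is why a bfloat16 buffer is created without a
    # corresponding numpy view.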
|
|
def _update_states_after_model_execute(
|
|
self, output_token_ids: torch.Tensor) -> None:
|
|
"""Update the cached states after model execution.
|
|
|
|
This is used for MTP/EAGLE for hybrid models, as in linear attention,
|
|
only the last token's state is kept. In MTP/EAGLE, for draft tokens
|
|
        the states are kept until we decide how many tokens are accepted for
|
|
each sequence, and a shifting is done during the next iteration
|
|
based on the number of accepted tokens.
|
|
"""
|
|
if not self.model_config.is_hybrid or not self.speculative_config:
|
|
return
|
|
|
|
# Find the number of accepted tokens for each sequence.
|
|
num_accepted_tokens = (torch.cat(
|
|
[
|
|
output_token_ids,
|
|
torch.full((output_token_ids.size(0), 1),
|
|
-1,
|
|
device=output_token_ids.device),
|
|
],
|
|
dim=1) == -1).int().argmax(-1).cpu().numpy()
|
|
for i, num_tokens in enumerate(num_accepted_tokens):
|
|
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
|
|
|
|
def _use_aclgraph(self) -> bool:
|
|
return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
|
|
|
|
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
|
# Remove finished requests from the cached states.
|
|
for req_id in scheduler_output.finished_req_ids:
|
|
self.requests.pop(req_id, None)
|
|
|
|
# Remove the finished requests from the persistent batch.
|
|
# NOTE(woosuk): There could be an edge case where finished_req_ids and
|
|
# scheduled_req_ids overlap. This happens when a request is aborted and
|
|
# then resubmitted with the same ID. In this case, we treat them as two
|
|
# distinct requests - clearing the cached states for the first request
|
|
# and handling the second as a new request.
|
|
for req_id in scheduler_output.finished_req_ids:
|
|
self.input_batch.remove_request(req_id)
|
|
for mm_hash in scheduler_output.free_encoder_mm_hashes:
|
|
self.encoder_cache.pop(mm_hash, None)
|
|
# Remove the unscheduled requests from the persistent batch.
|
|
# NOTE(woosuk): The unscheduled requests are either preempted requests
|
|
# or running requests that are not scheduled in this step. We remove
|
|
# them from the persistent batch but keep their cached states since
|
|
# they will be scheduled again sometime in the future.
|
|
scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
|
|
cached_req_ids = self.input_batch.req_id_to_index.keys()
|
|
unscheduled_req_ids = cached_req_ids - scheduled_req_ids
|
|
# NOTE(woosuk): The persistent batch optimization assumes that
|
|
# consecutive batches contain mostly the same requests. If batches
|
|
# have low request overlap (e.g., alternating between two distinct
|
|
# sets of requests), this optimization becomes very inefficient.
|
|
for req_id in unscheduled_req_ids:
|
|
self.input_batch.remove_request(req_id)
|
|
|
|
req_ids_to_add: list[str] = []
|
|
# Add new requests to the cached states.
|
|
for new_req_data in scheduler_output.scheduled_new_reqs:
|
|
req_id = new_req_data.req_id
|
|
sampling_params = new_req_data.sampling_params
|
|
pooling_params = new_req_data.pooling_params
|
|
|
|
if sampling_params and \
|
|
sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
|
generator = torch.Generator(device=self.device)
|
|
generator.manual_seed(sampling_params.seed)
|
|
else:
|
|
generator = None
|
|
|
|
if pooling_params:
|
|
assert (task := pooling_params.task) is not None, (
|
|
"You did not set `task` in the API")
|
|
model = cast(VllmModelForPooling, self.get_model())
|
|
to_update = model.pooler.get_pooling_updates(task)
|
|
to_update.apply(pooling_params)
|
|
|
|
backward_kwargs = {}
|
|
backward_kwargs["mm_features"] = new_req_data.mm_features
|
|
|
|
self.requests[req_id] = CachedRequestState(
|
|
req_id=req_id,
|
|
prompt_token_ids=new_req_data.prompt_token_ids,
|
|
sampling_params=sampling_params,
|
|
pooling_params=pooling_params,
|
|
generator=generator,
|
|
block_ids=new_req_data.block_ids,
|
|
num_computed_tokens=new_req_data.num_computed_tokens,
|
|
output_token_ids=[],
|
|
lora_request=new_req_data.lora_request,
|
|
**backward_kwargs,
|
|
)
|
|
|
|
            # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
|
|
if self.uses_mrope:
|
|
self._init_mrope_positions(self.requests[req_id])
|
|
|
|
req_ids_to_add.append(req_id)
|
|
|
|
# Update the states of the running/resumed requests.
|
|
is_last_rank = get_pp_group().is_last_rank
|
|
req_data = scheduler_output.scheduled_cached_reqs
|
|
for i, req_id in enumerate(req_data.req_ids):
|
|
req_state = self.requests[req_id]
|
|
num_computed_tokens = req_data.num_computed_tokens[i]
|
|
new_block_ids = req_data.new_block_ids[i]
|
|
resumed_from_preemption = req_data.resumed_from_preemption[i]
|
|
|
|
# Update the cached states.
|
|
req_state.num_computed_tokens = num_computed_tokens
|
|
|
|
if not is_last_rank:
|
|
# When using PP, the scheduler sends the sampled tokens back,
|
|
# because there's no direct communication between the first-
|
|
# stage worker and the last-stage worker.
|
|
new_token_ids = req_data.new_token_ids[i]
|
|
# Add the sampled token(s) from the previous step (if any).
|
|
# This doesn't include "unverified" tokens like spec tokens.
|
|
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
|
|
req_state.num_tokens)
|
|
if num_new_tokens == 1:
|
|
# Avoid slicing list in most common case.
|
|
req_state.output_token_ids.append(new_token_ids[-1])
|
|
elif num_new_tokens > 0:
|
|
req_state.output_token_ids.extend(
|
|
new_token_ids[-num_new_tokens:])
|
|
|
|
# Update the block IDs.
|
|
if not resumed_from_preemption:
|
|
if new_block_ids is not None:
|
|
# Append the new blocks to the existing block IDs.
|
|
for block_ids, new_ids in zip(req_state.block_ids,
|
|
new_block_ids):
|
|
block_ids.extend(new_ids)
|
|
else:
|
|
assert new_block_ids is not None
|
|
# The request is resumed from preemption.
|
|
# Replace the existing block IDs with the new ones.
|
|
req_state.block_ids = new_block_ids
|
|
|
|
req_index = self.input_batch.req_id_to_index.get(req_id)
|
|
if req_index is None:
|
|
# The request is not in the persistent batch.
|
|
# The request was either preempted and resumed later, or was not
|
|
# scheduled in the previous step and needs to be added again.
|
|
req_ids_to_add.append(req_id)
|
|
continue
|
|
|
|
# Update the persistent batch.
|
|
self.input_batch.num_computed_tokens_cpu[req_index] = (
|
|
num_computed_tokens)
|
|
if new_block_ids is not None:
|
|
self.input_batch.block_table.append_row(
|
|
new_block_ids, req_index)
|
|
|
|
# For the last rank, we don't need to update the token_ids_cpu
|
|
# because the sampled tokens are already cached.
|
|
if not is_last_rank:
|
|
# Add new_token_ids to token_ids_cpu.
|
|
start_token_index = num_computed_tokens
|
|
end_token_index = num_computed_tokens + len(new_token_ids)
|
|
self.input_batch.token_ids_cpu[
|
|
req_index,
|
|
start_token_index:end_token_index] = new_token_ids
|
|
self.input_batch.num_tokens_no_spec[
|
|
req_index] = end_token_index
|
|
self.input_batch.num_tokens[req_index] = end_token_index
|
|
|
|
# Add spec_token_ids to token_ids_cpu.
|
|
spec_token_ids = (
|
|
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
|
|
if spec_token_ids:
|
|
num_spec_tokens = len(spec_token_ids)
|
|
start_index = self.input_batch.num_tokens_no_spec[req_index]
|
|
end_token_index = start_index + num_spec_tokens
|
|
self.input_batch.token_ids_cpu[
|
|
req_index, start_index:end_token_index] = spec_token_ids
|
|
# NOTE(woosuk): `num_tokens` here may include spec tokens.
|
|
self.input_batch.num_tokens[req_index] += num_spec_tokens
|
|
|
|
# Add the new or resumed requests to the persistent batch.
|
|
# The smaller empty indices are filled first.
|
|
for req_id in req_ids_to_add:
|
|
req_state = self.requests[req_id]
|
|
self.input_batch.add_request(req_state)
|
|
|
|
# Condense the batched states if there are gaps left by removed requests
|
|
self.input_batch.condense()
|
|
        # Allow the attention backend to reorder the batch if it needs to.
|
|
self._may_reorder_batch(scheduler_output)
|
|
# Refresh batch metadata with any pending updates.
|
|
self.input_batch.refresh_metadata()
|
|
|
|
def _init_mrope_positions(self, req_state: CachedRequestState):
|
|
image_grid_thw = []
|
|
video_grid_thw = []
|
|
second_per_grid_ts = []
|
|
audio_feature_lengths = []
|
|
use_audio_in_video = False
|
|
assert req_state.mm_features is not None
|
|
for mm_feature in req_state.mm_features:
|
|
mm_item = mm_feature.data
|
|
if mm_item is None:
|
|
continue
|
|
mm_input = mm_item.get_data()
|
|
if (t := mm_input.get("image_grid_thw")) is not None:
|
|
image_grid_thw.append(t.tolist())
|
|
if (t := mm_input.get("video_grid_thw")) is not None:
|
|
video_grid_thw.append(t.tolist())
|
|
if (t := mm_input.get("second_per_grid_ts")) is not None:
|
|
second_per_grid_ts.append(t)
|
|
if (t := mm_input.get("audio_feature_lengths")) is not None:
|
|
audio_feature_lengths.append(t)
|
|
if mm_input.get("use_audio_in_video") is True:
|
|
use_audio_in_video = True
|
|
|
|
req_state.mrope_positions, req_state.mrope_position_delta = \
|
|
MRotaryEmbedding.get_input_positions_tensor(
|
|
req_state.prompt_token_ids,
|
|
hf_config=self.model_config.hf_config,
|
|
image_grid_thw=image_grid_thw,
|
|
video_grid_thw=video_grid_thw,
|
|
second_per_grid_ts=second_per_grid_ts,
|
|
audio_feature_lengths=audio_feature_lengths,
|
|
use_audio_in_video=use_audio_in_video,
|
|
)
|
|
|
|
def _sync_metadata_across_dp(
|
|
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
|
|
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
|
|
# TODO: In vLLM, the only thing that needs to be synced is num_tokens, but in
|
|
# our case, we still need to sync the other two flags as well. So we need to
|
|
        # include them in the all_reduce operation; moreover, we CANNOT skip it
|
|
# even if we are running in eager mode, which harms performance.
|
|
# FIXME: Restore the `or self.vllm_config.model_config.enforce_eager` here
|
|
# immediately once the other two flags are no longer needed.
|
|
if self.dp_size == 1:
|
|
return num_tokens, None, with_prefill, enable_dbo
|
|
|
|
# Sync num_tokens, with_prefill, enable_dbo across dp ranks
|
|
num_tokens_tensor = torch.tensor([
|
|
num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
|
|
],
|
|
dtype=torch.int32,
|
|
device="npu")
|
|
|
|
flags_tensor = torch.tensor(
|
|
[int(with_prefill), int(not enable_dbo)],
|
|
dtype=torch.int32,
|
|
device="npu")
|
|
|
|
packed_tensor = torch.cat([num_tokens_tensor, flags_tensor])
|
|
|
|
dist.all_reduce(packed_tensor, group=get_dp_group().device_group)
|
|
|
|
# Unpack the results
|
|
num_tokens_across_dp = packed_tensor[:-2]
|
|
synced_flags = packed_tensor[-2:]
|
|
|
|
max_tokens_across_dp = torch.max(num_tokens_across_dp).item()
|
|
global_with_prefill = bool(synced_flags[0])
|
|
global_enable_dbo = not bool(synced_flags[1])
|
|
|
|
# Create a tensor for num_tokens_after_padding
|
|
num_tokens_after_padding = torch.tensor([max_tokens_across_dp] *
|
|
self.dp_size,
|
|
device="npu",
|
|
dtype=torch.int32)
|
|
|
|
return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo
|
|
|
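    # Illustrative sketch of the packing trick used above (standalone,
    # assumed numbers): each DP rank contributes its token count in its own
    # slot plus two shared flags, and a single all_reduce(SUM) recovers
    # everything at once.
    #
    #     dp_size, dp_rank, num_tokens = 4, 1, 96
    #     packed = torch.zeros(dp_size + 2, dtype=torch.int32)
    #     packed[dp_rank] = num_tokens     # one slot per rank
    #     packed[-2] = int(with_prefill)   # sum > 0 means "some rank has prefill"
    #     packed[-1] = int(not enable_dbo) # inverted: sum > 0 means "some rank opted out"
    #     # dist.all_reduce(packed) sums the slots across ranks, after which
    #     # packed[:-2] holds every rank's token count and the flags are global.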
|
def _check_dbo_is_valid(self, query_lens: torch.Tensor,
|
|
attn_state: AscendAttentionState,
|
|
num_tokens: int) -> bool:
|
|
# do the checks for dp + dbo
|
|
if attn_state in [
|
|
AscendAttentionState.DecodeOnly,
|
|
AscendAttentionState.SpecDecoding
|
|
]:
|
|
return False
|
|
# considering the case that one dp rank may enable dbo while others may not
|
|
if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO:
|
|
return False
|
|
# TODO: remove it if token-level microbatch is enabled
|
|
[token_index,
|
|
seq_index] = compute_split_seq_index(query_lens, attn_state,
|
|
num_tokens)
|
|
if token_index == 0 or seq_index == 0 or seq_index == len(
|
|
query_lens) or num_tokens < 256:
|
|
return False
|
|
return True
|
|
|
|
def get_model(self) -> nn.Module:
|
|
# get raw model out of the aclgraph wrapper.
|
|
if isinstance(self.model, ACLGraphWrapper):
|
|
return self.model.unwrap()
|
|
return self.model
|
|
|
|
def get_supported_generation_tasks(self) -> "list[GenerationTask]":
|
|
model = self.get_model()
|
|
supported_tasks = list[GenerationTask]()
|
|
|
|
if is_text_generation_model(model):
|
|
supported_tasks.append("generate")
|
|
|
|
if supports_transcription(model):
|
|
if model.supports_transcription_only:
|
|
return ["transcription"]
|
|
|
|
supported_tasks.append("transcription")
|
|
|
|
return supported_tasks
|
|
|
|
def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
|
|
tasks = list[SupportedTask]()
|
|
|
|
if self.model_config.runner_type == "generate":
|
|
tasks.extend(self.get_supported_generation_tasks())
|
|
if self.model_config.runner_type == "pooling":
|
|
tasks.extend(self.get_supported_pooling_tasks())
|
|
|
|
return tuple(tasks)
|
|
|
|
def _make_attention_mask(self, seq_lens, position,
|
|
attn_state) -> torch.Tensor:
|
|
# Chunk Prefill situation.
|
|
if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa:
|
|
if torch.version.cann.startswith("8.3"):
|
|
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
|
else:
|
|
return self.attn_mask_builder.get_splitfuse_attn_mask(
|
|
seq_lens, position, self.dtype, self.device)
|
|
|
|
# Prefill without cache situation.
|
|
elif attn_state == AscendAttentionState.PrefillNoCache:
|
|
max_seq_len = max(seq_lens.max().item(), 0)
|
|
return self.attn_mask_builder.get_attn_mask(
|
|
max_seq_len, self.dtype, self.device)
|
|
# Prefill with cache hit.
|
|
elif attn_state == AscendAttentionState.PrefillCacheHit:
|
|
return self.attn_mask_builder.get_attn_mask(
|
|
128, self.dtype, self.device)
|
|
# Decode-only situation.
|
|
else:
|
|
return None
|
|
|
|
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
|
|
mrope_pos_ptr = 0
|
|
for index, req_id in enumerate(self.input_batch.req_ids):
|
|
req = self.requests[req_id]
|
|
assert req.mrope_positions is not None
|
|
|
|
num_computed_tokens = \
|
|
self.input_batch.num_computed_tokens_cpu[index]
|
|
num_scheduled_tokens = \
|
|
scheduler_output.num_scheduled_tokens[req_id]
|
|
num_prompt_tokens = len(req.prompt_token_ids)
|
|
|
|
if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
|
|
prompt_part_len = max(0,
|
|
num_prompt_tokens - num_computed_tokens)
|
|
completion_part_len = max(
|
|
0, num_scheduled_tokens - prompt_part_len)
|
|
else:
|
|
prompt_part_len = num_scheduled_tokens
|
|
completion_part_len = 0
|
|
|
|
assert num_scheduled_tokens == prompt_part_len + completion_part_len
|
|
|
|
if prompt_part_len > 0:
|
|
# prompt's mrope_positions are pre-computed
|
|
dst_start = mrope_pos_ptr
|
|
dst_end = mrope_pos_ptr + prompt_part_len
|
|
src_start = num_computed_tokens
|
|
src_end = num_computed_tokens + prompt_part_len
|
|
|
|
self.mrope_positions_cpu[:, dst_start:dst_end] = \
|
|
req.mrope_positions[:,src_start:src_end]
|
|
|
|
mrope_pos_ptr += prompt_part_len
|
|
|
|
if completion_part_len > 0:
|
|
# compute completion's mrope_positions on-the-fly
|
|
dst_start = mrope_pos_ptr
|
|
dst_end = mrope_pos_ptr + completion_part_len
|
|
MRotaryEmbedding.get_next_input_positions_tensor(
|
|
out=self.mrope_positions_np,
|
|
out_offset=dst_start,
|
|
mrope_position_delta=req.mrope_position_delta,
|
|
context_len=num_computed_tokens + prompt_part_len,
|
|
num_new_tokens=completion_part_len,
|
|
)
|
|
|
|
mrope_pos_ptr += completion_part_len
|
|
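    # Worked example for the split above (illustrative numbers): with
    # num_prompt_tokens=10, num_computed_tokens=8 and num_scheduled_tokens=5,
    # the request crosses the prompt boundary, so prompt_part_len = 10 - 8 = 2
    # (copied from the precomputed mrope_positions) and completion_part_len =
    # 5 - 2 = 3 (computed on the fly from mrope_position_delta).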
|
|
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
|
|
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
|
if not scheduled_encoder_inputs:
|
|
return
|
|
|
|
# Batch the multi-modal inputs.
|
|
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
|
|
scheduler_output)
|
|
encoder_outputs = []
|
|
|
|
for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
|
|
mm_kwargs,
|
|
device=self.device,
|
|
pin_memory=True,
|
|
):
|
|
# Run the encoder.
|
|
# `curr_group_outputs` is either of the following:
|
|
# 1. A tensor of shape (num_items, feature_size, hidden_size)
|
|
# in case feature_size is fixed across all multimodal items.
|
|
# 2. A list or tuple (length: num_items) of tensors, each of shape
|
|
# (feature_size, hidden_size) in case the feature size is dynamic
|
|
# depending on the input multimodal items.
|
|
curr_group_outputs = self.model.get_multimodal_embeddings(
|
|
**mm_kwargs_group)
|
|
|
|
sanity_check_mm_encoder_outputs(
|
|
curr_group_outputs,
|
|
expected_num_items=num_items,
|
|
)
|
|
|
|
for output in curr_group_outputs:
|
|
encoder_outputs.append(output)
|
|
|
|
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
|
|
self.encoder_cache[mm_hash] = scatter_mm_placeholders(
|
|
output,
|
|
is_embed=pos_info.is_embed,
|
|
)
|
|
|
|
def _batch_mm_kwargs_from_scheduler(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]:
|
|
"""Batch multimodal kwargs from scheduled encoder inputs.
|
|
|
|
Args:
|
|
scheduler_output: The scheduler output containing scheduled encoder
|
|
inputs.
|
|
|
|
Returns:
|
|
A tuple of (mm_kwargs, req_ids_pos) where:
|
|
- mm_kwargs: List of multimodal kwargs items to be batched
|
|
- mm_hashes_pos: List of (mm_hash, position_info) tuples
|
|
"""
|
|
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
|
if not scheduled_encoder_inputs:
|
|
return [], []
|
|
# Batch the multi-modal inputs.
|
|
mm_kwargs = list[MultiModalKwargsItem]()
|
|
# list of tuple (mm_hash, position_info)
|
|
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
|
|
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
|
req_state = self.requests[req_id]
|
|
assert req_state.mm_features is not None
|
|
for mm_input_id in encoder_input_ids:
|
|
mm_feature = req_state.mm_features[mm_input_id]
|
|
mm_hash = mm_feature.identifier
|
|
mm_kwargs.append(mm_feature.data)
|
|
mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
|
|
|
|
return mm_kwargs, mm_hashes_pos
|
|
|
|
def _gather_mm_embeddings(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
) -> list[torch.Tensor]:
|
|
|
|
def _iter_mm_features(req_state: CachedRequestState):
|
|
assert req_state.mm_features is not None
|
|
for mm_feature in req_state.mm_features:
|
|
pos_info = mm_feature.mm_position
|
|
yield mm_feature.identifier, pos_info, getattr(
|
|
pos_info, "is_embed", None)
|
|
|
|
mm_embeds: list[torch.Tensor] = []
|
|
|
|
for req_id in self.input_batch.req_ids:
|
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
|
req_id]
|
|
req_state = self.requests[req_id]
|
|
num_computed_tokens = req_state.num_computed_tokens
|
|
|
|
for mm_hash, pos_info, is_embed in _iter_mm_features(req_state):
|
|
start_pos = pos_info.offset
|
|
num_encoder_tokens = pos_info.length
|
|
|
|
if start_pos >= num_computed_tokens + num_scheduled_tokens:
|
|
break
|
|
if start_pos + num_encoder_tokens <= num_computed_tokens:
|
|
continue
|
|
|
|
start_idx = max(num_computed_tokens - start_pos, 0)
|
|
end_idx = min(
|
|
num_computed_tokens - start_pos + num_scheduled_tokens,
|
|
num_encoder_tokens,
|
|
)
|
|
assert start_idx < end_idx
|
|
|
|
encoder_output = self.encoder_cache.get(mm_hash, None)
|
|
assert encoder_output is not None, \
|
|
f"Encoder cache miss for {mm_hash}."
|
|
|
|
if is_embed is not None:
|
|
is_embed = is_embed[start_idx:end_idx]
|
|
|
|
mm_embeds_item = gather_mm_placeholders(
|
|
encoder_output[start_idx:end_idx],
|
|
is_embed=is_embed,
|
|
)
|
|
mm_embeds.append(mm_embeds_item)
|
|
return mm_embeds
|
|
|
|
def _get_cumsum_and_arange(
|
|
self,
|
|
num_tokens: np.ndarray,
|
|
cumsum_dtype: Optional[np.dtype] = None,
|
|
) -> tuple[np.ndarray, np.ndarray]:
|
|
"""Get the cumulative sum and batched arange of the given array.
|
|
# E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
|
|
# Equivalent to but faster than:
|
|
# np.concatenate([np.arange(n) for n in num_tokens])
|
|
"""
|
|
# Step 1. [2, 5, 3] -> [2, 7, 10]
|
|
cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
|
|
total_num_tokens = cu_num_tokens[-1]
|
|
# Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
|
|
cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
|
|
# Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
|
arange = self.arange_np[:total_num_tokens] - cumsums_offsets
|
|
|
|
return cu_num_tokens, arange
|
|
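    # Illustrative doctest-style check of the helper above (not executed here;
    # `runner` is assumed to be an NPUModelRunner instance):
    #
    #     >>> num_tokens = np.array([2, 5, 3], dtype=np.int32)
    #     >>> cu, ar = runner._get_cumsum_and_arange(num_tokens)
    #     >>> cu.tolist()
    #     [2, 7, 10]
    #     >>> ar.tolist()
    #     [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]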
|
|
def _prepare_input_ids(self, total_num_scheduled_tokens: int,
|
|
cu_num_tokens: np.ndarray) -> None:
|
|
"""Prepare the input IDs for the current batch.
|
|
|
|
Carefully handles the `prev_sampled_token_ids` which can be cached
|
|
from the previous engine iteration, in which case those tokens on the
|
|
        NPU need to be copied into the corresponding slots in input_ids."""
|
|
|
|
if self.input_batch.prev_sampled_token_ids is None:
|
|
# Normal scheduling case
|
|
self.input_ids[:total_num_scheduled_tokens].copy_(
|
|
self.input_ids_cpu[:total_num_scheduled_tokens],
|
|
non_blocking=True)
|
|
return
|
|
|
|
# Async scheduling case, where some decode requests from the previous
|
|
# iteration won't have entries in input_ids_cpu and need to be copied
|
|
# on the NPU from prev_sampled_token_ids.
|
|
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
|
|
assert prev_req_id_to_index is not None
|
|
flattened_indices = []
|
|
prev_common_req_indices = []
|
|
indices_match = True
|
|
max_flattened_index = -1
|
|
for req_id, cur_index in self.input_batch.req_id_to_index.items():
|
|
if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
|
|
prev_common_req_indices.append(prev_index)
|
|
# We need to compute the flattened input_ids index of the
|
|
# last token in each common request.
|
|
flattened_index = cu_num_tokens[cur_index].item() - 1
|
|
flattened_indices.append(flattened_index)
|
|
indices_match &= (prev_index == flattened_index)
|
|
max_flattened_index = max(max_flattened_index, flattened_index)
|
|
num_commmon_tokens = len(flattened_indices)
|
|
if num_commmon_tokens < total_num_scheduled_tokens:
|
|
# If not all requests are decodes from the last iteration,
|
|
# We need to copy the input_ids_cpu to the NPU first.
|
|
self.input_ids[:total_num_scheduled_tokens].copy_(
|
|
self.input_ids_cpu[:total_num_scheduled_tokens],
|
|
non_blocking=True)
|
|
if num_commmon_tokens == 0:
|
|
# No requests in common with the previous iteration
|
|
# So input_ids_cpu will have all the input ids.
|
|
return
|
|
if indices_match and max_flattened_index == (num_commmon_tokens - 1):
|
|
# Common-case optimization: the batch is unchanged
|
|
# and no reordering happened.
|
|
# The indices are both the same permutation of 0..N-1 so
|
|
# we can copy directly using a single slice.
|
|
self.input_ids[:num_commmon_tokens].copy_(
|
|
self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
|
|
0],
|
|
non_blocking=True)
|
|
return
|
|
# Upload the index tensors asynchronously
|
|
# so the scatter can be non-blocking.
|
|
input_ids_index_tensor = torch.tensor(flattened_indices,
|
|
dtype=torch.int64,
|
|
pin_memory=self.pin_memory).to(
|
|
self.device,
|
|
non_blocking=True)
|
|
prev_common_req_indices_tensor = torch.tensor(
|
|
prev_common_req_indices,
|
|
dtype=torch.int64,
|
|
pin_memory=self.pin_memory).to(self.device, non_blocking=True)
|
|
self.input_ids.scatter_(dim=0,
|
|
index=input_ids_index_tensor,
|
|
src=self.input_batch.prev_sampled_token_ids[
|
|
prev_common_req_indices_tensor, 0])
|
|
|
|
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
|
|
"""
|
|
Update the order of requests in the batch based on the attention
|
|
backend's needs. For example, some attention backends (namely MLA) may
|
|
want to separate requests based on if the attention computation will be
|
|
compute-bound or memory-bound.
|
|
|
|
Args:
|
|
scheduler_output: The scheduler output.
|
|
"""
|
|
        # Attention-free models have zero kv_cache_groups; however, models
|
|
# like Mamba are also attention free but use the kv_cache for
|
|
# keeping its internal state. This is why we check the number
|
|
# of kv_cache groups instead of solely checking
|
|
# for self.model_config.is_attention_free.
|
|
if len(self.kv_cache_config.kv_cache_groups) == 0:
|
|
return
|
|
|
|
if self.reorder_batch_threshold is not None:
|
|
reorder_batch_to_split_decodes_and_prefills(
|
|
self.input_batch,
|
|
scheduler_output,
|
|
decode_threshold=self.reorder_batch_threshold)
|
|
|
|
def _prepare_inputs(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor,
|
|
int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor],
|
|
Optional[torch.Tensor], Optional[torch.Tensor], int]:
|
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
|
assert total_num_scheduled_tokens > 0
|
|
num_reqs = self.input_batch.num_reqs
|
|
assert num_reqs > 0
|
|
|
|
# OPTIMIZATION: Start copying the block table first.
|
|
# This way, we can overlap the copy with the following CPU operations.
|
|
self.input_batch.block_table.commit_block_table(num_reqs)
|
|
|
|
# Get the number of scheduled tokens for each request.
|
|
req_ids = self.input_batch.req_ids
|
|
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
|
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
|
max_num_scheduled_tokens = num_scheduled_tokens.max()
|
|
num_valid_tokens = np.array([
|
|
num_tokens -
|
|
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
|
|
for num_tokens, i in zip(tokens, req_ids)
|
|
],
|
|
dtype=np.int32)
|
|
|
|
if (self.use_aclgraph and total_num_scheduled_tokens
|
|
<= self.aclgraph_batch_sizes[-1]):
|
|
# Add padding to the batch size.
|
|
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
|
total_num_scheduled_tokens)
|
|
else:
|
|
# Eager mode.
|
|
num_input_tokens = total_num_scheduled_tokens
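        # Illustrative example of the padding branch above (sizes are
        # assumptions): with capture sizes [1, 2, 4, 8, 16] and 5 scheduled
        # tokens, pad_for_cudagraph(5) would typically round the batch up to
        # the nearest captured size, i.e. num_input_tokens = 8.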
|
|
|
|
# Get the attention state.
|
|
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
|
|
num_valid_tokens)
|
|
self.attn_state = attn_state # type: ignore
|
|
|
|
# Determine if it's a splitfuse batch
|
|
with_prefill = attn_state not in [
|
|
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
|
]
|
|
|
|
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
|
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
|
|
attn_state,
|
|
total_num_scheduled_tokens)
|
|
|
|
# Get info across DP ranks.
|
|
# NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP,
|
|
# Otherwise, it's just max_tokens_across_dp_cpu
|
|
(maybe_padded_num_tokens, num_tokens_across_dp, with_prefill,
|
|
enable_dbo) = self._sync_metadata_across_dp(num_input_tokens,
|
|
with_prefill, enable_dbo)
|
|
|
|
# TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens
|
|
# We should consider removing maybe_padded_num_tokens later
|
|
num_input_tokens = maybe_padded_num_tokens
|
|
|
|
# Hot-Swap lora model
|
|
if self.lora_config:
|
|
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
|
|
|
# Get request indices.
|
|
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
|
req_indices = np.repeat(self.arange_np[:num_reqs],
|
|
num_scheduled_tokens)
|
|
|
|
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
|
|
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
|
cu_num_tokens, arange = self._get_cumsum_and_arange(
|
|
num_scheduled_tokens)
|
|
|
|
positions_np = self.positions_np[:total_num_scheduled_tokens]
|
|
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
|
|
arange,
|
|
out=positions_np)
|
|
|
|
# Calculate M-RoPE positions.
|
|
        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
|
|
if self.uses_mrope:
|
|
self._calc_mrope_positions(scheduler_output)
|
|
|
|
            # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
|
|
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
|
|
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
|
|
non_blocking=True)
|
|
|
|
# Get token indices.
|
|
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
|
|
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
|
|
# where M is the max_model_len.
|
|
token_indices = (positions_np +
|
|
req_indices * self.input_batch.token_ids_cpu.shape[1])
|
|
|
|
# Prepare input_ids.
|
|
# NOTE(woosuk): We use torch.index_select instead of np.take here
|
|
# because torch.index_select is much faster than np.take for large
|
|
# tensors.
|
|
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
|
|
0,
|
|
torch.from_numpy(token_indices),
|
|
out=self.input_ids_cpu[:total_num_scheduled_tokens])
|
|
|
|
# Prepare some information for building Attention-Metadata
|
|
# Compute and commit slot mapping
|
|
self.input_batch.block_table.compute_slot_mapping(
|
|
req_indices, positions_np)
|
|
self.input_batch.block_table.commit_slot_mapping(
|
|
total_num_scheduled_tokens)
|
|
|
|
self.query_start_loc_np[0] = 0
|
|
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
|
|
self.query_start_loc[:num_reqs + 1].copy_(
|
|
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
|
|
|
|
self.seq_lens_np[:num_reqs] = (
|
|
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
|
num_scheduled_tokens)
|
|
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
|
|
non_blocking=True)
|
|
|
|
# Fill unused with -1. Needed for reshape_and_cache
|
|
self.query_start_loc[num_reqs + 1:].fill_(-1)
|
|
self.seq_lens[num_reqs:].fill_(0)
|
|
|
|
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
|
|
|
# Copy the tensors to the NPU.
|
|
self.input_ids[:total_num_scheduled_tokens].copy_(
|
|
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
|
|
|
self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
|
|
self.positions[:num_input_tokens].copy_(
|
|
self.positions_cpu[:num_input_tokens], non_blocking=True)
|
|
|
|
# Make Attention metadata
|
|
positions_cpu = self.positions_cpu[:num_input_tokens]
|
|
positions = self.positions[:num_input_tokens]
|
|
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
|
|
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
|
|
position=positions_cpu,
|
|
attn_state=attn_state)
|
|
self.attn_state = attn_state # type: ignore
|
|
|
|
self.with_prefill = with_prefill
|
|
self.num_tokens_across_dp = num_tokens_across_dp
|
|
self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
|
|
attn_metadata: dict[str, Any] = {}
|
|
|
|
# Prepare input_ids
|
|
token_indices = (positions_np +
|
|
req_indices * self.input_batch.token_ids_cpu.shape[1])
|
|
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
|
|
0,
|
|
torch.from_numpy(token_indices),
|
|
out=self.input_ids_cpu[:total_num_scheduled_tokens])
|
|
# Copy the tensors to the NPU.
|
|
self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
|
|
|
|
# _prepare_inputs may reorder the batch, so we must gather
|
|
# multi-modal outputs after that to ensure the correct order
|
|
if self.is_multimodal_model:
|
|
# Run the multimodal encoder if any.
|
|
self._execute_mm_encoder(scheduler_output)
|
|
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
|
|
|
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
|
# embeddings), we always use embeddings (rather than token ids)
|
|
# as input to the multimodal model, even when the input is text.
|
|
input_ids = self.input_ids[:total_num_scheduled_tokens]
|
|
if mm_embeds:
|
|
inputs_embeds = self.model.get_input_embeddings(
|
|
input_ids, mm_embeds)
|
|
else:
|
|
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
|
# TODO(woosuk): Avoid the copy. Optimize.
|
|
self.inputs_embeds[:total_num_scheduled_tokens].copy_(
|
|
inputs_embeds)
|
|
inputs_embeds = self.inputs_embeds[:num_input_tokens]
|
|
input_ids = None
|
|
else:
|
|
# For text-only models, we use token ids as input.
|
|
# While it is possible to use embeddings as input just like the
|
|
# multimodal models, it is not desirable for performance since
|
|
# then the embedding layer is not included in the ACL graph.
|
|
input_ids = self.input_ids[:num_input_tokens]
|
|
inputs_embeds = None
|
|
positions = self.positions[:num_input_tokens]
|
|
input_ids, positions = self._update_input_ids_and_positions(
|
|
input_ids, positions, num_input_tokens, with_prefill,
|
|
maybe_padded_num_tokens)
|
|
|
|
if get_pp_group().is_first_rank:
|
|
intermediate_tensors = None
|
|
else:
|
|
assert intermediate_tensors is not None
|
|
assert self.intermediate_tensors is not None
|
|
for k, v in intermediate_tensors.items():
|
|
self.intermediate_tensors[k][:num_input_tokens].copy_(
|
|
v[:num_input_tokens], non_blocking=True)
|
|
intermediate_tensors = IntermediateTensors({
|
|
k: v[:num_input_tokens]
|
|
for k, v in self.intermediate_tensors.items()
|
|
})
|
|
|
|
use_spec_decode = len(
|
|
scheduler_output.scheduled_spec_decode_tokens) > 0
|
|
if not use_spec_decode:
|
|
# NOTE(woosuk): Due to chunked prefills, the batch may contain
|
|
# partial requests. While we should not sample any token
|
|
# from these partial requests, we do so for simplicity.
|
|
# We will ignore the sampled tokens from the partial requests.
|
|
# TODO: Support prompt logprobs.
|
|
spec_decode_metadata = None
|
|
logits_indices = torch.from_numpy(cu_num_tokens - 1).to(
|
|
self.device, non_blocking=True)
|
|
else:
|
|
# Get the number of draft tokens for each request.
|
|
# Iterate over the dictionary rather than all requests since not all
|
|
# requests have draft tokens.
|
|
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
|
|
for req_id, draft_token_ids in (
|
|
scheduler_output.scheduled_spec_decode_tokens.items()):
|
|
req_idx = self.input_batch.req_id_to_index[req_id]
|
|
num_draft_tokens[req_idx] = len(draft_token_ids)
|
|
|
|
spec_decode_metadata = self._calc_spec_decode_metadata(
|
|
num_draft_tokens, cu_num_tokens)
|
|
logits_indices = spec_decode_metadata.logits_indices
|
|
self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
|
|
self.num_draft_tokens.np[num_reqs:].fill(0)
|
|
self.num_draft_tokens.copy_to_gpu()
|
|
|
|
# Used in the below loop.
|
|
# query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
|
|
num_computed_tokens_cpu = (
|
|
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
|
spec_decode_common_attn_metadata = None
|
|
if use_spec_decode and self.need_accepted_tokens:
|
|
self.num_accepted_tokens.np[:num_reqs] = (
|
|
self.input_batch.num_accepted_tokens_cpu[:num_reqs])
|
|
self.num_accepted_tokens.np[num_reqs:].fill(1)
|
|
self.num_accepted_tokens.copy_to_gpu()
|
|
|
|
# Prepare the attention metadata for each KV cache group and make layers
|
|
# in the same group share the same metadata.
|
|
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
|
self.kv_cache_config.kv_cache_groups):
|
|
blk_table = self.input_batch.block_table[kv_cache_group_id]
|
|
blk_table_tensor = blk_table.get_device_tensor()
|
|
slot_mapping = blk_table.slot_mapping_cpu[:
|
|
total_num_scheduled_tokens]
|
|
self.slot_mapping[:total_num_scheduled_tokens].copy_(
|
|
slot_mapping[:total_num_scheduled_tokens],
|
|
non_blocking=True,
|
|
)
|
|
|
|
# Make AscendCommonAttentionMetadata
|
|
common_attn_metadata = AscendCommonAttentionMetadata(
|
|
query_start_loc=self.query_start_loc[:num_reqs + 1],
|
|
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
|
|
seq_lens_cpu=self.seq_lens_cpu,
|
|
seq_lens=self.seq_lens_cpu[:num_reqs],
|
|
num_reqs=num_reqs,
|
|
num_actual_tokens=total_num_scheduled_tokens,
|
|
actual_seq_lengths_q=self.actual_seq_lengths_q,
|
|
# TODO: change this to the right block table for linear attn
|
|
block_table_tensor=blk_table_tensor[:num_reqs],
|
|
slot_mapping=self.slot_mapping,
|
|
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
|
positions=self.positions,
|
|
attn_mask=self.attn_mask,
|
|
spec_attn_mask=self.spec_attn_mask,
|
|
attn_state=self.attn_state,
|
|
enable_dbo_across_dp=enable_dbo,
|
|
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
|
|
max_query_len=max_num_scheduled_tokens,
|
|
graph_pad_size=self.graph_pad_size,
|
|
decode_token_per_req=self.decode_token_per_req,
|
|
cos=self.cos,
|
|
sin=self.sin,
|
|
)
|
|
|
|
if self.speculative_config and \
|
|
spec_decode_common_attn_metadata is None:
|
|
spec_decode_common_attn_metadata = common_attn_metadata
|
|
|
|
for attn_group in self.attn_groups[kv_cache_group_id]:
|
|
common_prefix_len = 0
|
|
extra_attn_metadata_args = {}
|
|
builder = attn_group.get_metadata_builder()
|
|
if isinstance(builder, GDNAttentionMetadataBuilder):
|
|
if use_spec_decode:
|
|
extra_attn_metadata_args = dict(
|
|
num_accepted_tokens=self.num_accepted_tokens.
|
|
gpu[:num_reqs],
|
|
num_draft_tokens=self.num_draft_tokens.
|
|
gpu[:num_reqs],
|
|
)
|
|
attn_metadata_i = builder.build(
|
|
common_prefix_len=common_prefix_len,
|
|
common_attn_metadata=common_attn_metadata,
|
|
**extra_attn_metadata_args)
|
|
else:
|
|
attn_metadata_i = builder.build(
|
|
common_prefix_len=common_prefix_len,
|
|
common_attn_metadata=common_attn_metadata,
|
|
model=self.get_model(),
|
|
**extra_attn_metadata_args)
|
|
|
|
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
|
|
attn_metadata_i.num_input_tokens = num_input_tokens
|
|
for layer_name in attn_group.layer_names:
|
|
attn_metadata[layer_name] = attn_metadata_i
|
|
|
|
if lmhead_tp_enable():
|
|
max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
|
|
logits_indices = nn.functional.pad(
|
|
logits_indices,
|
|
(0, max_num_reqs_across_dp - logits_indices.shape[0]))
|
|
|
|
return (attn_metadata, positions, num_scheduled_tokens,
|
|
num_input_tokens, num_tokens_across_dp,
|
|
maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
|
|
input_ids, inputs_embeds, intermediate_tensors,
|
|
max_num_scheduled_tokens)
|
|
|
|
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
|
|
maybe_padded_num_tokens,
|
|
input_ids, positions,
|
|
intermediate_tensors,
|
|
inputs_embeds):
|
|
assert self.model is not None
|
|
hidden_states = self.model(
|
|
input_ids=input_ids,
|
|
positions=positions,
|
|
intermediate_tensors=intermediate_tensors,
|
|
inputs_embeds=inputs_embeds,
|
|
)
|
|
|
|
forward_context = get_forward_context()
|
|
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
|
# TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
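# For FULL-graph replay, the attention params captured in the graph are
# refreshed out of band on self.update_stream; MLA models use a dedicated
# update path.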
|
|
if self.vllm_config.model_config.use_mla:
|
|
# FIXME: Try using `auto_dispatch_capture=True`
|
|
update_mla_attn_params(self.update_stream, forward_context,
|
|
maybe_padded_num_tokens)
|
|
else:
|
|
update_attn_params(self.update_stream, forward_context,
|
|
maybe_padded_num_tokens)
|
|
|
|
if get_forward_context().sp_enabled:
|
|
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
|
|
pad_size = get_forward_context().pad_size
|
|
if pad_size > 0:
|
|
hidden_states = hidden_states[:-pad_size, :]
|
|
return hidden_states
|
|
|
|
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
|
|
num_valid_tokens):
|
|
ascend_config = get_ascend_config()
|
|
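# Classify the batch for attention: prefill with no cached prefix, pure
# decode, speculative decode, chunked prefill, or prefill with a cache hit.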
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
|
attn_state = AscendAttentionState.PrefillNoCache
|
|
# We assume this is the decode stage: even if a prefill occurs, only one token per request missed the cache.
|
|
elif np.all(num_scheduled_tokens == 1):
|
|
attn_state = AscendAttentionState.DecodeOnly
|
|
if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
|
|
# SpecDecoding now supports seq_len=1 and seq_len=2
|
|
# In the prefill-decode disaggregation scenario, SpecDecoding needs to support seq_len=1
|
|
attn_state = AscendAttentionState.SpecDecoding
|
|
# Speculative decoding.
|
|
elif np.all(num_valid_tokens == 1):
|
|
if self.drafter and (self.drafter.name == SpecDcodeType.EAGLE
|
|
or self.drafter.name == SpecDcodeType.EAGLE3):
|
|
attn_state = AscendAttentionState.ChunkedPrefill
|
|
else:
|
|
attn_state = AscendAttentionState.SpecDecoding
|
|
# splitfuse
|
|
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
|
|
attn_state = AscendAttentionState.ChunkedPrefill
|
|
else:
|
|
attn_state = AscendAttentionState.PrefillCacheHit
|
|
return attn_state
|
|
|
|
def _update_graph_pad_size(self, with_prefill, graph_pad_size):
|
|
self.graph_pad_size = -1
|
|
|
|
def _update_input_ids_and_positions(self, input_ids, positions,
|
|
num_input_tokens, with_prefill,
|
|
maybe_padded_num_tokens):
|
|
if self.uses_mrope:
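# M-RoPE keeps positions in a (3, num_tokens) buffer (one row per rotary
# section), so slice along the token dimension here.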
|
|
positions = self.mrope_positions[:, :num_input_tokens]
|
|
return input_ids, positions
|
|
|
|
def _calc_spec_decode_metadata(
|
|
self,
|
|
num_draft_tokens: np.ndarray,
|
|
cu_num_scheduled_tokens: np.ndarray,
|
|
) -> SpecDecodeMetadata:
|
|
# Inputs:
|
|
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
|
|
# num_draft_tokens: [ 3, 0, 2, 0, 1]
|
|
# Outputs:
|
|
# cu_num_draft_tokens: [ 3, 3, 5, 5, 6]
|
|
# logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106,
|
|
# 206, 207, 208]
|
|
# target_logits_indices: [ 0, 1, 2, 5, 6, 9]
|
|
# bonus_logits_indices: [ 3, 4, 7, 8, 10]
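# NOTE: logits_indices picks the last (num_draft_tokens + 1) token positions
# of each request, so its length equals sum(num_draft_tokens + 1).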
|
|
|
|
# Compute the logits indices.
|
|
# [4, 1, 3, 1, 2]
|
|
num_sampled_tokens = num_draft_tokens + 1
|
|
# Step 1. [4, 5, 8, 9, 11]
|
|
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
|
|
total_num_sampled_tokens = cu_num_sampled_tokens[-1]
|
|
# Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
|
|
cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
|
|
num_sampled_tokens)
|
|
# Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
|
|
arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets
|
|
# Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
|
|
logits_indices = np.repeat(
|
|
cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
|
|
# Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
|
|
logits_indices += arange
|
|
|
|
# Compute the bonus logits indices.
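# [3, 4, 7, 8, 10]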
|
|
bonus_logits_indices = cu_num_sampled_tokens - 1
|
|
|
|
# Compute the draft logits indices.
|
|
# [3, 3, 5, 5, 6]
|
|
cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
|
|
total_num_draft_tokens = cu_num_draft_tokens[-1]
|
|
# [0, 0, 0, 3, 3, 5]
|
|
cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens,
|
|
num_draft_tokens)
|
|
# [0, 1, 2, 0, 1, 0]
|
|
arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets
|
|
# [0, 0, 0, 5, 5, 9]
|
|
target_logits_indices = np.repeat(
|
|
cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens)
|
|
# [0, 1, 2, 5, 6, 9]
|
|
target_logits_indices += arange
|
|
|
|
# TODO: Optimize the CPU -> NPU copy.
|
|
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
|
|
self.device, non_blocking=True)
|
|
logits_indices = torch.from_numpy(logits_indices).to(self.device,
|
|
non_blocking=True)
|
|
target_logits_indices = torch.from_numpy(target_logits_indices).to(
|
|
self.device, non_blocking=True)
|
|
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
|
|
self.device, non_blocking=True)
|
|
|
|
# Compute the draft token ids.
|
|
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
|
|
draft_token_ids = self.input_ids[logits_indices]
|
|
draft_token_ids = draft_token_ids[target_logits_indices + 1]
|
|
|
|
metadata = SpecDecodeMetadata(
|
|
draft_token_ids=draft_token_ids,
|
|
num_draft_tokens=num_draft_tokens.tolist(),
|
|
cu_num_draft_tokens=cu_num_draft_tokens,
|
|
target_logits_indices=target_logits_indices,
|
|
bonus_logits_indices=bonus_logits_indices,
|
|
logits_indices=logits_indices,
|
|
)
|
|
return metadata
|
|
|
|
def apply_grammar_bitmask(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
logits: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
grammar_bitmask = scheduler_output.grammar_bitmask
|
|
|
|
# We receive the structured output bitmask from the scheduler,
|
|
# compacted to contain bitmasks only for structured output requests.
|
|
# The order of the requests in the bitmask is not guaranteed to be the
|
|
# same as the order of the requests in the gpu runner's batch. We need
|
|
# to sort the bitmask to match the order of the requests used here.
|
|
|
|
# Get the batch indices of the structured output requests.
|
|
# Keep track of the number of speculative tokens scheduled for every
|
|
# request in the batch, as the logit indices are offset by this amount.
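# Example (illustrative): with batch order [A, B] and two speculative tokens
# scheduled for A, A's bitmask rows map to logit indices 0-2 and B's single
# row maps to index 3.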
|
|
struct_out_req_batch_indices: dict[str, int] = {}
|
|
cumulative_offset = 0
|
|
seq = sorted(self.input_batch.req_id_to_index.items(),
|
|
key=lambda x: x[1])
|
|
for req_id, batch_index in seq:
|
|
logit_index = batch_index + cumulative_offset
|
|
cumulative_offset += len(
|
|
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
|
|
if req_id in scheduler_output.structured_output_request_ids:
|
|
struct_out_req_batch_indices[req_id] = logit_index
|
|
|
|
out_indices = []
|
|
|
|
# Reorder the bitmask to match the order of the requests in the batch.
|
|
sorted_bitmask = np.zeros_like(grammar_bitmask,
|
|
shape=(logits.shape[0],
|
|
grammar_bitmask.shape[1]))
|
|
cumulative_index = 0
|
|
seq = sorted(scheduler_output.structured_output_request_ids.items(),
|
|
key=lambda x: x[1])
|
|
for req_id, _ in seq:
|
|
logit_index = struct_out_req_batch_indices[req_id]
|
|
num_spec_tokens = len(
|
|
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
|
|
for i in range(1 + num_spec_tokens):
|
|
sorted_bitmask[logit_index + i] = \
|
|
grammar_bitmask[cumulative_index + i]
|
|
out_indices.append(logit_index + i)
|
|
cumulative_index += 1 + num_spec_tokens
|
|
grammar_bitmask = sorted_bitmask
|
|
|
|
# Serialization of np.ndarray is much more efficient than a tensor,
|
|
# so we receive it in that format.
|
|
grammar_bitmask = torch.from_numpy(grammar_bitmask)
|
|
|
|
# NOTE:
|
|
# 1. XGrammar bitmask applying only supports CPU and GPU.
|
|
# 2. The logits and bitmask should be on the same device.
|
|
# 3. XGrammar logits on CPU only supports float32 dtype.
|
|
logits_dtype = logits.dtype
|
|
logits = logits.to("cpu").float()
|
|
xgr.apply_token_bitmask_inplace(
|
|
logits,
|
|
grammar_bitmask,
|
|
indices=out_indices,
|
|
)
|
|
return logits.to(self.device).to(logits_dtype)
|
|
|
|
def propose_draft_token_ids(
|
|
self,
|
|
valid_sampled_token_ids: list[list[int]],
|
|
sampling_metadata: SamplingMetadata,
|
|
scheduler_output: "SchedulerOutput",
|
|
spec_decode_metadata: SpecDecodeMetadata,
|
|
positions: torch.Tensor,
|
|
num_scheduled_tokens: int,
|
|
hidden_states: torch.Tensor,
|
|
attn_metadata: dict[str, Any],
|
|
aux_hidden_states: torch.Tensor = None,
|
|
) -> Optional[list[list[int]]]:
|
|
if not self.drafter:
|
|
# Speculative decoding is not enabled.
|
|
draft_token_ids = None
|
|
else:
|
|
draft_token_ids = self.drafter.generate_token_ids(
|
|
valid_sampled_token_ids, sampling_metadata, scheduler_output,
|
|
spec_decode_metadata, positions, num_scheduled_tokens,
|
|
hidden_states, attn_metadata, aux_hidden_states)
|
|
return draft_token_ids
|
|
|
|
def _pool(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
num_scheduled_tokens: int,
|
|
num_scheduled_tokens_np: np.ndarray,
|
|
finished_sending: Optional[set[str]] = None,
|
|
finished_recving: Optional[set[str]] = None,
|
|
kv_connector_output: Optional["KVConnectorOutput"] = None,
|
|
) -> ModelRunnerOutput:
|
|
assert self.input_batch.num_reqs ==\
|
|
len(self.input_batch.pooling_params), \
|
|
"Either all or none of the requests in" \
|
|
" a batch must be pooling request"
|
|
|
|
hidden_states = hidden_states[:num_scheduled_tokens]
|
|
pooling_metadata = self.input_batch.pooling_metadata
|
|
pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
|
|
device=hidden_states.device)
|
|
seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]
|
|
|
|
model = cast(VllmModelForPooling, self.model)
|
|
raw_pooler_output = model.pooler(
|
|
hidden_states=hidden_states,
|
|
pooling_metadata=pooling_metadata,
|
|
)
|
|
raw_pooler_output = json_map_leaves(
|
|
lambda x: x.to("cpu", non_blocking=True),
|
|
raw_pooler_output,
|
|
)
|
|
torch.npu.synchronize()
|
|
|
|
pooler_output: list[Optional[torch.Tensor]] = []
|
|
for raw_output, seq_len, prompt_len in zip(
|
|
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
|
|
output = raw_output if seq_len == prompt_len else None
|
|
pooler_output.append(output)
|
|
|
|
return ModelRunnerOutput(
|
|
req_ids=self.input_batch.req_ids,
|
|
req_id_to_index=self.input_batch.req_id_to_index,
|
|
sampled_token_ids=[],
|
|
logprobs=None,
|
|
prompt_logprobs_dict={},
|
|
pooler_output=pooler_output,
|
|
kv_connector_output=kv_connector_output,
|
|
)
|
|
|
|
def _select_moe_comm_method(self, num_tokens: int,
|
|
with_prefill: bool) -> MoECommType:
|
|
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
|
|
are designed for expert parallelism.
|
|
2. If expert parallel is enabled, we need to consider the soc version and the
|
|
number of tokens. This is based on the observation that all-gather is more
|
|
efficient than all-to-all when running on A2.
|
|
|
|
a. For A2, we choose from MC2 and all-gather.
|
|
|
|
b. For A3, we choose from MC2 and all-to-all.
|
|
|
|
In both cases, we use MC2 when the number of tokens is smaller than
|
|
its capacity threshold.
|
|
|
|
Args:
|
|
num_tokens (int): The number of tokens in the current batch.
|
|
|
|
Raises:
|
|
ValueError: If the soc version is unsupported.
|
|
|
|
Returns:
|
|
MoECommType: The selected MoE communication method.
|
|
"""
|
|
soc_version = get_ascend_soc_version()
|
|
quant_type = getattr(self.vllm_config.model_config.hf_config,
|
|
'moe_quantize', None)
|
|
model_type = self.vllm_config.model_config.hf_config.model_type
|
|
|
|
if not self.parallel_config.enable_expert_parallel:
|
|
moe_comm_type = MoECommType.ALLGATHER
|
|
elif soc_version in {AscendSocVersion.A2}:
|
|
if (num_tokens <= self.mc2_tokens_capacity
|
|
and self.parallel_config.world_size_across_dp >= 16):
|
|
moe_comm_type = MoECommType.MC2
|
|
else:
|
|
# Currently, w4a8_dynamic does not support allgatherep
|
|
if quant_type == "w4a8_dynamic":
|
|
moe_comm_type = MoECommType.ALLTOALL
|
|
else:
|
|
moe_comm_type = MoECommType.ALLGATHER
|
|
|
|
elif soc_version in {AscendSocVersion.A3}:
|
|
moe_comm_type = (MoECommType.MC2
|
|
if num_tokens <= self.mc2_tokens_capacity else
|
|
MoECommType.ALLTOALL)
|
|
else:
|
|
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
|
|
|
if moe_comm_type == MoECommType.ALLGATHER and with_prefill:
|
|
moe_comm_type = MoECommType.NAIVE_MULTICAST
|
|
|
|
# PanguProMoE only supports allgather
|
|
if model_type == "PanguProMoE":
|
|
moe_comm_type = MoECommType.ALLGATHER
|
|
|
|
if is_global_first_rank():
|
|
logger.debug(f"num_tokens: {num_tokens}, "
|
|
f"moe_comm_type: {moe_comm_type}")
|
|
return moe_comm_type
|
|
|
|
@torch.inference_mode()
|
|
def execute_model(
|
|
self,
|
|
scheduler_output: "SchedulerOutput",
|
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
|
|
with ProfileExecuteDuration().capture_async("prepare input"):
|
|
self._update_states(scheduler_output)
|
|
if not scheduler_output.total_num_scheduled_tokens:
|
|
if not has_kv_transfer_group():
|
|
logger.debug(
|
|
"skip this step for we receive the data from remote disaggregate prefill node"
|
|
)
|
|
# Return empty ModelRunnerOutput if there's no work to do.
|
|
return EMPTY_MODEL_RUNNER_OUTPUT
|
|
return self.kv_connector_no_forward(scheduler_output)
|
|
|
|
if self.dynamic_eplb:
|
|
self.eplb_updator.forward_before()
|
|
|
|
(attn_metadata, positions, num_scheduled_tokens_np,
|
|
num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens,
|
|
logits_indices, spec_decode_metadata, input_ids, inputs_embeds,
|
|
intermediate_tensors,
|
|
max_query_len) = (self._prepare_inputs(scheduler_output,
|
|
intermediate_tensors))
|
|
|
|
if self.dynamic_eplb:
|
|
self.eplb_updator.take_update_info_from_eplb_process()
|
|
|
|
moe_comm_type = self._select_moe_comm_method(num_input_tokens,
|
|
self.with_prefill)
|
|
|
|
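# A uniform decode batch means every scheduled request received exactly
# uniform_decode_query_len tokens (pure decode, or decode plus its
# speculative tokens).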
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
|
|
scheduler_output.total_num_scheduled_tokens
|
|
== self.input_batch.num_reqs * max_query_len)
|
|
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
|
uniform_decode=uniform_decode)
|
|
aclgraph_runtime_mode, batch_descriptor = \
|
|
self.aclgraph_dispatcher.dispatch(batch_descriptor)
|
|
|
|
# Run forward pass
|
|
with ProfileExecuteDuration().capture_async("forward"):
|
|
with set_ascend_forward_context(
|
|
attn_metadata,
|
|
self.vllm_config,
|
|
num_tokens=num_input_tokens,
|
|
num_tokens_across_dp=num_tokens_across_dp,
|
|
with_prefill=self.with_prefill,
|
|
reserved_mc2_mask=self.reserved_mc2_mask,
|
|
moe_comm_type=moe_comm_type,
|
|
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
|
batch_descriptor=batch_descriptor,
|
|
num_actual_tokens=scheduler_output.
|
|
total_num_scheduled_tokens,
|
|
prefetch_stream=self.prefetch_stream,
|
|
model_instance=self.model,
|
|
weight_prefetch_method=self.weight_prefetch_method):
|
|
self.maybe_setup_kv_connector(scheduler_output)
|
|
|
|
hidden_states = self._generate_process_reqs_hidden_states(
|
|
attn_metadata, self.with_prefill, maybe_padded_num_tokens,
|
|
input_ids, positions, intermediate_tensors, inputs_embeds)
|
|
|
|
self.maybe_wait_for_kv_save()
|
|
finished_sending, finished_recving = self.get_finished_kv_transfer(
|
|
scheduler_output)
|
|
|
|
aux_hidden_states = None
|
|
if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
|
|
hidden_states, aux_hidden_states = hidden_states
|
|
|
|
kv_connector_output = KVConnectorOutput(
|
|
finished_sending=finished_sending,
|
|
finished_recving=finished_recving)
|
|
finished_sending = None
|
|
finished_recving = None
|
|
with ProfileExecuteDuration().capture_async("post process"):
|
|
# Broadcast PP output for external_launcher (torchrun)
|
|
# to make sure we are synced across pp ranks
|
|
# TODO: Support overlapping micro-batches
|
|
# https://github.com/vllm-project/vllm/issues/18019
|
|
broadcast_pp_output = \
|
|
self.parallel_config.distributed_executor_backend \
|
|
== "external_launcher" and len(get_pp_group().ranks) > 0
|
|
if not get_pp_group().is_last_rank:
|
|
# For mid-pipeline stages, return the hidden states.
|
|
if not broadcast_pp_output:
|
|
hidden_states.kv_connector_output = kv_connector_output
|
|
return hidden_states
|
|
assert isinstance(hidden_states, IntermediateTensors)
|
|
get_pp_group().send_tensor_dict(
|
|
hidden_states.tensors, all_gather_group=get_tp_group())
|
|
logits = None
|
|
else:
|
|
if self.input_batch.pooling_params:
|
|
return self._pool(
|
|
hidden_states,
|
|
scheduler_output.total_num_scheduled_tokens,
|
|
num_scheduled_tokens_np, finished_sending,
|
|
finished_recving, kv_connector_output)
|
|
sample_hidden_states = hidden_states[logits_indices]
|
|
logits = self.model.compute_logits(sample_hidden_states)
|
|
if broadcast_pp_output:
|
|
model_output_broadcast_data = {
|
|
"logits": logits.contiguous(),
|
|
} if logits is not None else {}
|
|
model_output_broadcast_data = get_pp_group(
|
|
).broadcast_tensor_dict(model_output_broadcast_data,
|
|
src=len(get_pp_group().ranks) - 1)
|
|
assert model_output_broadcast_data is not None
|
|
logits = model_output_broadcast_data["logits"]
|
|
|
|
# Apply structured output bitmasks if present
|
|
if scheduler_output.grammar_bitmask is not None:
|
|
logits = self.apply_grammar_bitmask(scheduler_output, logits)
|
|
|
|
# Sample the next token and get logprobs if needed.
|
|
sampling_metadata = self.input_batch.sampling_metadata
|
|
if spec_decode_metadata is None:
|
|
if lmhead_tp_enable() and logits is not None:
|
|
logits = logits[:self.input_batch.num_reqs]
|
|
sampler_output = self.sampler(
|
|
logits=logits,
|
|
sampling_metadata=sampling_metadata,
|
|
)
|
|
else:
|
|
if lmhead_tp_enable() and logits is not None:
|
|
logits = logits[:len(spec_decode_metadata.logits_indices)]
|
|
# When indexing with a tensor (bonus_logits_indices), PyTorch
|
|
# creates a new tensor with separate storage from the original
|
|
# logits tensor. This means any in-place operations on bonus_logits
|
|
# won't affect the original logits tensor.
|
|
assert logits is not None
|
|
bonus_logits = logits[
|
|
spec_decode_metadata.bonus_logits_indices]
|
|
sampler_output = self.sampler(
|
|
logits=bonus_logits,
|
|
sampling_metadata=sampling_metadata,
|
|
)
|
|
bonus_token_ids = sampler_output.sampled_token_ids
|
|
|
|
# Just like `bonus_logits`, `target_logits` is a new tensor with
|
|
# separate storage from the original `logits` tensor. Therefore,
|
|
# it is safe to update `target_logits` in place.
|
|
target_logits = logits[
|
|
spec_decode_metadata.target_logits_indices]
|
|
output_token_ids = self.rejection_sampler(
|
|
spec_decode_metadata,
|
|
None, # draft_probs
|
|
target_logits,
|
|
bonus_token_ids,
|
|
sampling_metadata,
|
|
)
|
|
sampler_output.sampled_token_ids = output_token_ids
|
|
if self.need_accepted_tokens:
|
|
self._update_states_after_model_execute(output_token_ids)
|
|
|
|
discard_sampled_tokens_req_indices: list[int] = []
|
|
# TODO(woosuk): The following loop can be slow since it iterates over
|
|
# the requests one by one. Optimize.
|
|
|
|
for i, req_id in enumerate(self.input_batch.req_ids):
|
|
req_state = self.requests[req_id]
|
|
seq_len = (req_state.num_computed_tokens +
|
|
scheduler_output.num_scheduled_tokens[req_id])
|
|
if seq_len < req_state.num_tokens:
|
|
# Ignore the sampled token.
|
|
# Rewind the generator state as if the token was not sampled.
|
|
generator = self.input_batch.generators.get(i)
|
|
if generator is not None:
|
|
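# NOTE: assumes one sampling call advances the generator offset by 4
# (torch internals), so rewinding by 4 discards exactly one sample.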
generator.set_offset(generator.get_offset() - 4)
|
|
discard_sampled_tokens_req_indices.append(i)
|
|
|
|
# Copy some objects so they don't get modified after returning.
|
|
# This is important when using async scheduling.
|
|
req_ids_output_copy = self.input_batch.req_ids.copy()
|
|
req_id_to_index_output_copy = \
|
|
self.input_batch.req_id_to_index.copy()
|
|
|
|
# NOTE: NPU -> CPU Sync happens here.
|
|
# Move as many CPU operations as possible before this sync point.
|
|
logprobs_tensors = sampler_output.logprobs_tensors
|
|
logprobs_lists = logprobs_tensors.tolists() \
|
|
if logprobs_tensors is not None else None
|
|
|
|
# Compute prompt logprobs if needed.
|
|
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
|
|
hidden_states[:scheduler_output.total_num_scheduled_tokens],
|
|
scheduler_output,
|
|
)
|
|
|
|
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
|
|
sampled_token_ids = sampler_output.sampled_token_ids
|
|
if not self.use_async_scheduling:
|
|
# Get the valid generated tokens.
|
|
max_gen_len = sampled_token_ids.shape[-1]
|
|
if max_gen_len == 1:
|
|
# No spec decode tokens.
|
|
valid_sampled_token_ids = sampled_token_ids.tolist()
|
|
else:
|
|
# Includes spec decode tokens.
|
|
valid_sampled_token_ids = self.rejection_sampler.parse_output(
|
|
sampled_token_ids,
|
|
self.input_batch.vocab_size,
|
|
)
|
|
# Mask out the sampled tokens that should not be sampled.
|
|
for i in discard_sampled_tokens_req_indices:
|
|
valid_sampled_token_ids[i].clear()
|
|
else:
|
|
valid_sampled_token_ids = []
|
|
invalid_req_indices = list(discard_sampled_tokens_req_indices)
|
|
invalid_req_indices_set = set(invalid_req_indices)
|
|
assert sampled_token_ids.shape[-1] == 1
|
|
|
|
# Cache the sampled tokens on the NPU and avoid CPU sync.
|
|
# These will be copied into input_ids in the next step
|
|
# when preparing inputs.
|
|
self.input_batch.prev_sampled_token_ids = \
|
|
sampled_token_ids
|
|
self.input_batch.prev_sampled_token_ids_invalid_indices = \
|
|
invalid_req_indices_set
|
|
self.input_batch.prev_req_id_to_index = {
|
|
req_id: i
|
|
for i, req_id in enumerate(self.input_batch.req_ids)
|
|
if i not in invalid_req_indices_set
|
|
}
|
|
# Cache the sampled tokens in the model runner, so that the scheduler
|
|
# doesn't need to send them back.
|
|
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
|
|
# the sampled tokens back, because there's no direct communication
|
|
# between the first-stage worker and the last-stage worker.
|
|
for req_idx in range(num_sampled_tokens):
|
|
if self.use_async_scheduling:
|
|
sampled_ids = [-1] * 1 if \
|
|
req_idx not in invalid_req_indices_set else None
|
|
else:
|
|
sampled_ids = valid_sampled_token_ids[req_idx]
|
|
if not sampled_ids:
|
|
continue
|
|
|
|
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
|
|
end_idx = start_idx + len(sampled_ids)
|
|
assert end_idx <= self.model_config.max_model_len, (
|
|
"Sampled token IDs exceed the max model length. "
|
|
f"Total number of tokens: {end_idx} > max_model_len: "
|
|
f"{self.model_config.max_model_len}")
|
|
|
|
self.input_batch.token_ids_cpu[req_idx,
|
|
start_idx:end_idx] = sampled_ids
|
|
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
|
|
self.input_batch.num_tokens[req_idx] = end_idx
|
|
req_id = self.input_batch.req_ids[req_idx]
|
|
req_state = self.requests[req_id]
|
|
req_state.output_token_ids.extend(sampled_ids)
|
|
|
|
if self.speculative_config:
|
|
self._draft_token_ids = self.propose_draft_token_ids(
|
|
valid_sampled_token_ids,
|
|
sampling_metadata,
|
|
scheduler_output,
|
|
spec_decode_metadata,
|
|
positions,
|
|
scheduler_output.total_num_scheduled_tokens,
|
|
hidden_states,
|
|
attn_metadata,
|
|
aux_hidden_states,
|
|
)
|
|
|
|
if has_kv_transfer_group():
|
|
get_kv_transfer_group().clear_connector_metadata()
|
|
|
|
extra_args = ({"kv_connector_output": kv_connector_output})
|
|
|
|
model_runner_output = ModelRunnerOutput(
|
|
req_ids=req_ids_output_copy,
|
|
req_id_to_index=req_id_to_index_output_copy,
|
|
sampled_token_ids=valid_sampled_token_ids,
|
|
logprobs=logprobs_lists,
|
|
prompt_logprobs_dict=prompt_logprobs_dict,
|
|
pooler_output=[],
|
|
**extra_args,
|
|
)
|
|
|
|
durations = ProfileExecuteDuration().pop_captured_sync()
|
|
if durations:
|
|
dr_str = [
|
|
f"[{tag}]:{duration:.2f}ms"
|
|
for tag, duration in durations.items()
|
|
]
|
|
captured_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill"
|
|
logger.info("Profile execute duration [%s]:%s", captured_name,
|
|
" ".join(dr_str))
|
|
if self.dynamic_eplb:
|
|
self.eplb_updator.forward_end()
|
|
if not self.use_async_scheduling:
|
|
return model_runner_output
|
|
|
|
return AsyncNPUModelRunnerOutput(
|
|
model_runner_output=model_runner_output,
|
|
sampled_token_ids=sampled_token_ids,
|
|
invalid_req_indices=invalid_req_indices,
|
|
async_output_copy_stream=self.async_output_copy_stream,
|
|
)
|
|
|
|
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
|
|
if self._draft_token_ids is None:
|
|
return None
|
|
req_ids = self.input_batch.req_ids
|
|
if isinstance(self._draft_token_ids, torch.Tensor):
|
|
draft_token_ids = self._draft_token_ids.tolist()
|
|
else:
|
|
draft_token_ids = self._draft_token_ids
|
|
self._draft_token_ids = None
|
|
return DraftTokenIds(req_ids, draft_token_ids)
|
|
|
|
def kv_connector_no_forward(
|
|
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
|
|
with set_ascend_forward_context(None, self.vllm_config):
|
|
self.maybe_setup_kv_connector(scheduler_output)
|
|
finished_sending, finished_recving = (
|
|
self.get_finished_kv_transfer(scheduler_output))
|
|
# When no forward pass happens because remote KV is being received,
# one round of dummy inference is still necessary
# to prevent hangs in the collective calls.
|
|
|
|
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
|
|
output.kv_connector_output = KVConnectorOutput(
|
|
finished_sending=finished_sending,
|
|
finished_recving=finished_recving)
|
|
return output
|
|
|
|
@staticmethod
|
|
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
|
|
# Update the KVConnector with the KVConnector metadata before forward().
|
|
if has_kv_transfer_group():
|
|
kv_connector = get_kv_transfer_group()
|
|
assert isinstance(kv_connector, KVConnectorBase_V1)
|
|
assert scheduler_output.kv_connector_metadata is not None
|
|
kv_connector.bind_connector_metadata(
|
|
scheduler_output.kv_connector_metadata)
|
|
|
|
kv_connector.start_load_kv(get_forward_context())
|
|
|
|
@staticmethod
|
|
def maybe_wait_for_kv_save() -> None:
|
|
if has_kv_transfer_group():
|
|
get_kv_transfer_group().wait_for_save()
|
|
|
|
@staticmethod
|
|
def get_finished_kv_transfer(
|
|
scheduler_output: "SchedulerOutput",
|
|
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
|
if has_kv_transfer_group():
|
|
return get_kv_transfer_group().get_finished(
|
|
scheduler_output.finished_req_ids)
|
|
return None, None
|
|
|
|
def _build_attention_metadata(self, create_mixed_batch, num_reqs,
|
|
num_tokens, max_query_len, force_attention):
|
|
attn_metadata: Optional[dict[str, Any]] = None
|
|
|
|
if force_attention:
|
|
attn_metadata = {}
|
|
|
|
if create_mixed_batch:
|
|
raise NotImplementedError(
|
|
"force_attention=True is not supported for mixed batches.")
|
|
else:
|
|
seq_lens = self.model_config.max_model_len
|
|
self.seq_lens_np[:num_reqs] = seq_lens
|
|
self.seq_lens_np[num_reqs:] = 0
|
|
|
|
num_computed_tokens_cpu = (
|
|
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
|
|
|
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
|
self.kv_cache_config.kv_cache_groups):
|
|
block_table_tensor = self.input_batch.block_table[
|
|
kv_cache_group_id].get_device_tensor()
|
|
common_attn_metadata = AscendCommonAttentionMetadata(
|
|
query_start_loc=self.query_start_loc[:num_reqs + 1],
|
|
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
|
|
1],
|
|
seq_lens_cpu=self.seq_lens_cpu,
|
|
seq_lens=self.seq_lens_cpu[:num_reqs],
|
|
num_reqs=num_reqs,
|
|
num_actual_tokens=num_tokens,
|
|
actual_seq_lengths_q=self.actual_seq_lengths_q,
|
|
block_table_tensor=block_table_tensor[:num_reqs],
|
|
slot_mapping=self.slot_mapping,
|
|
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
|
positions=self.positions,
|
|
attn_mask=self.attn_mask,
|
|
spec_attn_mask=self.spec_attn_mask,
|
|
attn_state=self.attn_state,
|
|
max_query_len=max_query_len,
|
|
decode_token_per_req=self.decode_token_per_req,
|
|
cos=self.cos,
|
|
sin=self.sin,
|
|
)
|
|
|
|
for attn_group in self.attn_groups[kv_cache_group_id]:
|
|
builder = attn_group.get_metadata_builder()
|
|
attn_metadata_i = builder.build_for_graph_capture(
|
|
common_attn_metadata, AscendAttentionState.DecodeOnly,
|
|
self.get_model())
|
|
for layer_name in kv_cache_group_spec.layer_names:
|
|
attn_metadata[layer_name] = attn_metadata_i
|
|
|
|
return attn_metadata
|
|
|
|
def _generate_dummy_run_hidden_states(self, with_prefill,
|
|
is_torchair_compile, input_ids,
|
|
positions, attn_metadata, num_tokens,
|
|
intermediate_tensors, inputs_embeds):
|
|
hidden_states = self.model(input_ids=input_ids,
|
|
positions=positions,
|
|
intermediate_tensors=intermediate_tensors,
|
|
inputs_embeds=inputs_embeds)
|
|
forward_context = get_forward_context()
|
|
assert forward_context is not None
|
|
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
|
|
not forward_context.capturing:
|
|
if self.vllm_config.model_config.use_mla:
|
|
# FIXME: Try using `auto_dispatch_capture=True`
|
|
update_mla_attn_params(self.update_stream, forward_context,
|
|
positions.shape[0])
|
|
else:
|
|
update_attn_params(self.update_stream, forward_context,
|
|
positions.shape[0])
|
|
|
|
if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
|
|
hidden_states, _ = hidden_states
|
|
|
|
return hidden_states
|
|
|
|
@torch.inference_mode()
|
|
def _dummy_run(
|
|
self,
|
|
num_tokens: int,
|
|
with_prefill: bool = False,
|
|
is_torchair_compile: bool = False,
|
|
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
|
|
force_attention: bool = False,
|
|
uniform_decode: bool = False,
|
|
) -> torch.Tensor:
|
|
# only support eager mode and piecewise graph now
|
|
assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in {
|
|
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
|
|
}
|
|
|
|
# Padding for DP
|
|
(num_tokens, num_tokens_across_dp, with_prefill,
|
|
_) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)
|
|
|
|
moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill)
|
|
|
|
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(), we are using
# different graphs and/or modes for mixed prefill-decode batches vs.
# uniform decode batches. A uniform decode batch means that all
# requests have identical query length, except for a potential virtual
# (shorter) request in the batch that accounts for padding.
|
|
# Uniform decode batch could either be common pure decode, where
|
|
# max_query_len == 1, or speculative decode, where
|
|
# max_query_len == 1 + num_spec_decode_tokens.
|
|
|
|
# When setting max_query_len = 1, we switch to and capture the optimized
|
|
# routine of FA2 for pure decode, i.e., Flashdecode + an optimization
|
|
# for GQA/MQA.
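# Example (illustrative): with one speculative token per request,
# uniform_decode_query_len is 2, so a 256-token uniform-decode capture
# corresponds to 128 requests of 2 tokens each.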
|
|
max_query_len = self.uniform_decode_query_len if uniform_decode else \
|
|
num_tokens
|
|
|
|
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
|
# for dummy run with LoRA so that the num_reqs collectively
|
|
# has num_tokens in total.
|
|
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
|
|
max_num_reqs = self.scheduler_config.max_num_seqs
|
|
if uniform_decode:
|
|
num_reqs = cdiv(num_tokens, max_query_len)
|
|
num_scheduled_tokens_list = [max_query_len] * num_reqs
|
|
if num_tokens % max_query_len != 0:
|
|
num_scheduled_tokens_list[-1] = num_tokens % max_query_len
|
|
else:
|
|
if with_prefill:
|
|
num_reqs = num_tokens
|
|
else:
|
|
num_reqs = (num_tokens + self.decode_token_per_req -
|
|
1) // self.decode_token_per_req
|
|
num_reqs = min(num_reqs, max_num_reqs)
|
|
min_tokens_per_req = num_tokens // num_reqs
|
|
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
|
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
|
assert sum(num_scheduled_tokens_list) == num_tokens
|
|
assert len(num_scheduled_tokens_list) == num_reqs
|
|
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
|
dtype=np.int32)
|
|
|
|
# Force a dummy run on the prefill stage when this node is deemed a KV producer.
|
|
if self.is_kv_producer and not self.is_kv_consumer:
|
|
with_prefill = True
|
|
|
|
# TODO(cmq): check if with_prefill is reasonable
|
|
attn_metadata = self._build_attention_metadata(
|
|
False,
|
|
num_reqs=num_reqs,
|
|
num_tokens=num_tokens,
|
|
max_query_len=max_query_len,
|
|
force_attention=force_attention,
|
|
)
|
|
|
|
if not self.in_profile_run and self.dynamic_eplb:
|
|
self.eplb_updator.forward_before()
|
|
|
|
with self.maybe_dummy_run_with_lora(self.lora_config,
|
|
num_scheduled_tokens):
|
|
if self.is_multimodal_model:
|
|
input_ids = None
|
|
inputs_embeds = self.inputs_embeds[:num_tokens]
|
|
else:
|
|
input_ids = self.input_ids[:num_tokens]
|
|
inputs_embeds = None
|
|
|
|
if self.uses_mrope:
|
|
positions = self.mrope_positions[:, :num_tokens]
|
|
else:
|
|
positions = self.positions[:num_tokens]
|
|
|
|
if get_pp_group().is_first_rank:
|
|
intermediate_tensors = None
|
|
else:
|
|
if self.intermediate_tensors is None:
|
|
self.intermediate_tensors = (
|
|
self.model.make_empty_intermediate_tensors(
|
|
batch_size=num_tokens,
|
|
dtype=self.dtype,
|
|
device=self.device))
|
|
intermediate_tensors = IntermediateTensors({
|
|
k: v[:num_tokens]
|
|
for k, v in self.intermediate_tensors.items()
|
|
})
|
|
|
|
# Dispatch to select the aclgraph runtime mode and the matching batch descriptor.
|
|
_ag_mode, batch_descriptor = \
|
|
self.aclgraph_dispatcher.dispatch(
|
|
BatchDescriptor(num_tokens=num_tokens,
|
|
uniform_decode=uniform_decode))
|
|
if aclgraph_runtime_mode is not None:
|
|
# we allow forcing NONE even when the dispatcher disagrees, to support
# warm-ups for aclgraph capture
|
|
assert aclgraph_runtime_mode == CUDAGraphMode.NONE or \
|
|
aclgraph_runtime_mode == _ag_mode, (
|
|
f"Aclgraph runtime mode mismatch at dummy_run. "
|
|
f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}.")
|
|
else:
|
|
aclgraph_runtime_mode = _ag_mode
|
|
|
|
need_dummy_logits = (not self.in_profile_run
|
|
and lmhead_tp_enable())
|
|
|
|
if need_dummy_logits:
|
|
max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
|
|
dummy_indices = torch.zeros(max_num_reqs_across_dp,
|
|
dtype=torch.int32)
|
|
|
|
def dummy_compute_logits(hidden_states):
|
|
return self.model.compute_logits(
|
|
hidden_states[dummy_indices])
|
|
|
|
with set_ascend_forward_context(
|
|
attn_metadata,
|
|
self.vllm_config,
|
|
num_tokens=num_tokens,
|
|
num_tokens_across_dp=num_tokens_across_dp,
|
|
with_prefill=with_prefill,
|
|
in_profile_run=self.in_profile_run,
|
|
reserved_mc2_mask=self.reserved_mc2_mask,
|
|
moe_comm_type=moe_comm_type,
|
|
num_actual_tokens=0,
|
|
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
|
batch_descriptor=batch_descriptor,
|
|
prefetch_stream=self.prefetch_stream,
|
|
model_instance=self.model,
|
|
weight_prefetch_method=self.weight_prefetch_method):
|
|
hidden_states = self._generate_dummy_run_hidden_states(
|
|
with_prefill, is_torchair_compile, input_ids, positions,
|
|
attn_metadata, num_tokens, intermediate_tensors,
|
|
inputs_embeds)
|
|
if need_dummy_logits:
|
|
dummy_compute_logits(hidden_states)
|
|
|
|
if self.drafter:
|
|
self.drafter.dummy_run(
|
|
num_tokens=num_tokens,
|
|
with_prefill=with_prefill,
|
|
skip_attn=True,
|
|
num_reqs=num_reqs,
|
|
num_tokens_across_dp=num_tokens_across_dp)
|
|
if need_dummy_logits:
|
|
dummy_compute_logits(hidden_states)
|
|
if self.in_profile_run and self.dynamic_eplb:
|
|
self.model.clear_all_moe_loads()
|
|
if not self.in_profile_run and self.dynamic_eplb:
|
|
self.eplb_updator.take_update_info_from_eplb_process()
|
|
self.eplb_updator.forward_end()
|
|
return hidden_states
|
|
|
|
@contextmanager
|
|
def set_in_profile_run(self):
|
|
self.in_profile_run = True
|
|
try:
|
|
yield
|
|
finally:
|
|
self.in_profile_run = False
|
|
|
|
def profile_run(self) -> None:
|
|
# Trigger compilation for general shape.
|
|
with self.set_in_profile_run():
|
|
hidden_states = self._dummy_run(self.max_num_tokens,
|
|
with_prefill=True)
|
|
# MC2 will consume additional NPU memory.
|
|
# Therefore, we need to run the MC2 path once here to complete its initialization,
|
|
# allowing vLLM to correctly estimate the maximum memory required.
|
|
if self.max_num_tokens > self.mc2_tokens_capacity and \
|
|
self._select_moe_comm_method(
|
|
self.mc2_tokens_capacity,
|
|
with_prefill=True) == MoECommType.MC2:
|
|
self._dummy_run(self.mc2_tokens_capacity, with_prefill=True)
|
|
|
|
output = None
|
|
if get_pp_group().is_last_rank:
|
|
if self.is_pooling_model:
|
|
output = self._dummy_pooler_run(hidden_states)
|
|
else:
|
|
# For profiling, use the maximum num_reqs and have them collectively
# cover the maximum num_tokens.
|
|
min_tokens_per_req = self.max_num_tokens // self.max_num_reqs
|
|
num_scheduled_tokens_list = [min_tokens_per_req
|
|
] * self.max_num_reqs
|
|
num_scheduled_tokens_list[
|
|
-1] += self.max_num_tokens % self.max_num_reqs
|
|
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
|
dtype=np.int32)
|
|
logit_indices = np.cumsum(num_scheduled_tokens) - 1
|
|
# TODO: need to run a dummy sampler for the generate task
|
|
hidden_states = hidden_states[logit_indices]
|
|
output = self.model.compute_logits(hidden_states)
|
|
|
|
NPUPlatform.synchronize()
|
|
del hidden_states, output
|
|
self.encoder_cache.clear()
|
|
gc.collect()
|
|
|
|
def _dummy_pooler_run_task(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
task: PoolingTask,
|
|
) -> PoolerOutput:
|
|
num_tokens = hidden_states.shape[0]
|
|
max_num_reqs = self.scheduler_config.max_num_seqs
|
|
num_reqs = min(num_tokens, max_num_reqs)
|
|
min_tokens_per_req = num_tokens // num_reqs
|
|
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
|
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
|
assert sum(num_scheduled_tokens_list) == num_tokens
|
|
assert len(num_scheduled_tokens_list) == num_reqs
|
|
|
|
req_num_tokens = num_tokens // num_reqs
|
|
|
|
dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
|
|
dtype=torch.int32,
|
|
device=self.device)
|
|
|
|
model = cast(VllmModelForPooling, self.get_model())
|
|
dummy_pooling_params = PoolingParams(task=task)
|
|
to_update = model.pooler.get_pooling_updates(task)
|
|
to_update.apply(dummy_pooling_params)
|
|
|
|
dummy_prompt_lens = torch.tensor(
|
|
num_scheduled_tokens_list,
|
|
device="cpu",
|
|
)
|
|
dummy_metadata = PoolingMetadata(
|
|
prompt_lens=dummy_prompt_lens,
|
|
prompt_token_ids=dummy_token_ids,
|
|
pooling_params=[dummy_pooling_params] * num_reqs,
|
|
)
|
|
|
|
dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list,
|
|
device=hidden_states.device)
|
|
|
|
try:
|
|
return model.pooler(hidden_states=hidden_states,
|
|
pooling_metadata=dummy_metadata)
|
|
except RuntimeError as e:
|
|
if 'out of memory' in str(e):
|
|
raise RuntimeError(
|
|
"CUDA out of memory occurred when warming up pooler "
|
|
f"({task=}) with {num_reqs} dummy requests. Please try "
|
|
"lowering `max_num_seqs` or `gpu_memory_utilization` when "
|
|
"initializing the engine.") from e
|
|
else:
|
|
raise e
|
|
|
|
@torch.inference_mode()
|
|
def _dummy_pooler_run(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
) -> PoolerOutput:
|
|
# Find the task that has the largest output for subsequent steps
|
|
output_size = dict[PoolingTask, float]()
|
|
for task in self.get_supported_pooling_tasks():
|
|
# Run a full batch with each task to ensure none of them OOMs
|
|
output = self._dummy_pooler_run_task(hidden_states, task)
|
|
output_size[task] = sum(o.nbytes for o in output)
|
|
del output # Allow GC
|
|
|
|
max_task = max(output_size.items(), key=lambda x: x[1])[0]
|
|
return self._dummy_pooler_run_task(hidden_states, max_task)
|
|
|
|
def eplb_warmup(self):
|
|
if self.dynamic_eplb and not self.is_eplb_warmuped:
|
|
self.is_eplb_warmuped = True
|
|
self.eplb_adaptor = VllmEplbAdaptor(model=self.model)
|
|
self.eplb_loader.set_adator(self.eplb_adaptor)
|
|
self.eplb_updator.set_adaptor(self.eplb_adaptor)
|
|
self.eplb_updator.warm_up_eplb()
|
|
|
|
def load_model(self) -> None:
|
|
logger.info("Starting to load model %s...", self.model_config.model)
|
|
|
|
with DeviceMemoryProfiler() as m: # noqa: SIM117
|
|
self.model = get_model(vllm_config=self.vllm_config)
|
|
if self.dynamic_eplb:
|
|
model_register(self.model, self.model_config)
|
|
if is_310p():
|
|
from vllm.model_executor.layers.linear import (
|
|
MergedColumnParallelLinear, QKVParallelLinear,
|
|
RowParallelLinear)
|
|
for module in self.model.modules():
|
|
if isinstance(module,
|
|
(MergedColumnParallelLinear,
|
|
QKVParallelLinear, RowParallelLinear)):
|
|
module.weight.data = self._convert_torch_format(
|
|
module.weight.data)
|
|
if self.drafter:
|
|
logger.info("Loading drafter model...")
|
|
self.drafter.load_model(self.model)
|
|
if self.drafter.name == SpecDcodeType.EAGLE3:
|
|
self.model.set_aux_hidden_state_layers(
|
|
self.model.get_eagle3_aux_hidden_state_layers())
|
|
|
|
if self.lora_config:
|
|
self.model = self.load_lora_model(self.model, self.vllm_config,
|
|
self.device)
|
|
logger.info("Loading model weights took %.4f GB",
|
|
m.consumed_memory / float(2**30))
|
|
|
|
# wrap the model with full graph wrapper if needed.
|
|
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
|
self.update_stream = torch.npu.Stream()
|
|
set_graph_params(self.compilation_config.cudagraph_capture_sizes)
|
|
self.model = ACLGraphWrapper(self.model,
|
|
self.vllm_config,
|
|
runtime_mode=CUDAGraphMode.FULL)
|
|
|
|
def _convert_torch_format(self, tensor):
|
|
if ACL_FORMAT == ACL_FORMAT_FRACTAL_NZ \
|
|
and not is_enable_nz():
|
|
return tensor
|
|
tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
|
|
return tensor
|
|
|
|
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
|
"""
|
|
Initialize KV cache based on `kv_cache_config`.
|
|
Args:
|
|
kv_cache_config: Configuration for the KV cache, including the KV
|
|
cache size of each layer
|
|
"""
|
|
kv_cache_config = deepcopy(kv_cache_config)
|
|
self.kv_cache_config = kv_cache_config
|
|
self.initialize_attn_backend(kv_cache_config)
|
|
self.use_hybrid_blocks = (len(self.attn_groups) > 1)
|
|
# NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`.
|
|
self.need_accepted_tokens = any([
|
|
isinstance(attn_group[0].kv_cache_spec, MambaSpec)
|
|
for attn_group in self.attn_groups
|
|
])
|
|
|
|
self.may_reinitialize_input_batch(kv_cache_config)
|
|
|
|
if self.ascend_config.is_deepseek_sfa:
|
|
kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
|
|
kv_cache_config)
|
|
elif self.model_config.is_deepseek_mla:
|
|
kv_caches = self.initialize_kv_cache_tensors_deepseek_mla(
|
|
kv_cache_config)
|
|
else:
|
|
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
|
|
|
|
if has_kv_transfer_group():
|
|
get_kv_transfer_group().register_kv_caches(kv_caches)
|
|
|
|
def _align_memory(self, tensor: torch.Tensor,
|
|
alignment: int) -> torch.Tensor:
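# Round the start address up to the next `alignment` boundary and return
# a view of `tensor` whose underlying storage begins on that boundary.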
|
|
data_ptr = tensor.data_ptr()
|
|
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
|
|
offset = (aligned_addr - data_ptr) // tensor.element_size()
|
|
return tensor[int(offset):]
|
|
|
|
def initialize_kv_cache_tensors_deepseek_sfa(
|
|
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
|
kv_cache_sizes = {}
|
|
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
|
assert len(kv_cache_tensor.shared_by) == 1, (
|
|
"KV cache tensor shared by multiple layers is not supported in "
|
|
"NPU.")
|
|
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
|
|
|
|
kv_caches: Dict[str, torch.Tensor] = {}
|
|
for group in self._kv_cache_spec_attn_group_iterator():
|
|
kv_cache_spec = group.kv_cache_spec
|
|
attn_backend = group.backend
|
|
for layer_name in group.layer_names:
|
|
if layer_name in self.runner_only_attn_layers:
|
|
continue
|
|
tensor_size = kv_cache_sizes[layer_name]
|
|
num_blocks = tensor_size // kv_cache_spec.page_size_bytes
|
|
if self.vllm_config.additional_config.get(
|
|
"kv_cache_dtype", None) == 'int8':
|
|
kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
|
|
num_blocks, kv_cache_spec.block_size,
|
|
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
|
elif hasattr(
|
|
attn_backend, "get_supported_block_size"
|
|
) and not self.model_config.is_deepseek_mla and not self.ascend_config.is_deepseek_sfa:
|
|
block_size = attn_backend.get_supported_block_size()[0]
|
|
block_size_chunk = kv_cache_spec.block_size // block_size
|
|
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
|
num_blocks * block_size_chunk, block_size,
|
|
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
|
else:
|
|
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
|
num_blocks, kv_cache_spec.block_size,
|
|
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
|
dtype = kv_cache_spec.dtype
|
|
|
|
alignment = 2 * 1024 * 1024
|
|
num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
|
|
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
|
nope_dim = head_size - rope_dim
|
|
nope_cache_shape = (num_blocks, block_size, num_kv_heads,
|
|
nope_dim)
|
|
rope_cache_shape = (num_blocks, block_size, num_kv_heads,
|
|
rope_dim)
|
|
#### k cache
|
|
# TODO(zzzzwwjj): wait for transformers to add these params
|
|
k_cache_shape = (num_blocks, block_size, 1, 128)
|
|
if self.vllm_config.kv_transfer_config is None:
|
|
# For the non-disaggregated PD scenario, allocate the KV cache in the normal way
|
|
rope_cache = torch.zeros(rope_cache_shape,
|
|
dtype=dtype,
|
|
device=self.device)
|
|
nope_cache = torch.zeros(nope_cache_shape,
|
|
dtype=dtype,
|
|
device=self.device)
|
|
rope_cache = self._convert_torch_format(rope_cache)
|
|
nope_cache = self._convert_torch_format(nope_cache)
|
|
|
|
#### k cache
|
|
k_cache = torch.zeros(k_cache_shape,
|
|
dtype=dtype,
|
|
device=self.device)
|
|
k_cache = self._convert_torch_format(k_cache)
|
|
else:
|
|
|
|
# In order to transfer the kv cache through the register_memory api from llmdatadist, the memory
# address should be aligned to 2M. In most cases, torch_npu can allocate 2M-aligned memory, but
# we also found some exceptions during testing, so we manually align the memory here; this part
# of the code may consume 2M * 2 * elem_size extra memory per layer.
|
|
nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
|
|
nope_allocate_shape_alignment = nope_allocate_shape + alignment
|
|
rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
|
|
rope_allocate_shape_alignment = rope_allocate_shape + alignment
|
|
|
|
nope_cache = torch.zeros(nope_allocate_shape_alignment,
|
|
dtype=dtype,
|
|
device=self.device)
|
|
rope_cache = torch.zeros(rope_allocate_shape_alignment,
|
|
dtype=dtype,
|
|
device=self.device)
|
|
#### k cache
|
|
# TODO(zzzzwwjj): wait for transformers to add these params
|
|
k_allocate_shape = num_blocks * block_size * 1 * 128
|
|
k_allocate_shape_alignment = k_allocate_shape + alignment
|
|
k_cache = torch.zeros(k_allocate_shape_alignment,
|
|
dtype=dtype,
|
|
device=self.device)
|
|
|
|
nope_cache = self._align_memory(
|
|
nope_cache,
|
|
alignment)[:nope_allocate_shape].view(nope_cache_shape)
|
|
rope_cache = self._align_memory(
|
|
rope_cache,
|
|
alignment)[:rope_allocate_shape].view(rope_cache_shape)
|
|
k_cache = self._align_memory(
|
|
k_cache,
|
|
alignment)[:k_allocate_shape].view(k_cache_shape)
|
|
|
|
kv_caches[layer_name] = (nope_cache, rope_cache, k_cache)
|
|
bind_kv_cache(kv_caches,
|
|
self.compilation_config.static_forward_context,
|
|
self.kv_caches)
|
|
|
|
return kv_caches
|
|
|
|
def initialize_kv_cache_tensors_deepseek_mla(
|
|
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
|
kv_cache_sizes = {}
|
|
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
|
assert len(kv_cache_tensor.shared_by) == 1, (
|
|
"KV cache tensor shared by multiple layers is not supported in "
|
|
"NPU.")
|
|
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
|
|
|
|
kv_caches: Dict[str, torch.Tensor] = {}
|
|
for group in self._kv_cache_spec_attn_group_iterator():
|
|
kv_cache_spec = group.kv_cache_spec
|
|
attn_backend = group.backend
|
|
for layer_name in group.layer_names:
|
|
if layer_name in self.runner_only_attn_layers:
|
|
continue
|
|
tensor_size = kv_cache_sizes[layer_name]
|
|
num_blocks = tensor_size // kv_cache_spec.page_size_bytes
|
|
if self.vllm_config.additional_config.get(
|
|
"kv_cache_dtype", None) == 'int8':
|
|
kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
|
|
num_blocks, kv_cache_spec.block_size,
|
|
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
|
elif hasattr(attn_backend, "get_supported_block_size"
|
|
) and not self.model_config.is_deepseek_mla:
|
|
block_size = attn_backend.get_supported_block_size()[0]
|
|
block_size_chunk = kv_cache_spec.block_size // block_size
|
|
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
|
num_blocks * block_size_chunk, block_size,
|
|
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
|
else:
|
|
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
|
num_blocks, kv_cache_spec.block_size,
|
|
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
|
dtype = kv_cache_spec.dtype
|
|
|
|
alignment = 2 * 1024 * 1024
|
|
num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
|
|
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
|
nope_dim = head_size - rope_dim
|
|
nope_cache_shape = (num_blocks, block_size, num_kv_heads,
|
|
nope_dim)
|
|
rope_cache_shape = (num_blocks, block_size, num_kv_heads,
|
|
rope_dim)
|
|
if self.vllm_config.kv_transfer_config is None:
|
|
# For the non-disaggregated PD scenario, allocate the KV cache in the normal way
|
|
rope_cache = torch.zeros(rope_cache_shape,
|
|
dtype=dtype,
|
|
device=self.device)
|
|
nope_cache = torch.zeros(nope_cache_shape,
|
|
dtype=dtype,
|
|
device=self.device)
|
|
rope_cache = self._convert_torch_format(rope_cache)
|
|
nope_cache = self._convert_torch_format(nope_cache)
|
|
else:
|
|
|
|
# In order to transfer the kv cache through the register_memory api from llmdatadist, the memory
# address should be aligned to 2M. In most cases, torch_npu can allocate 2M-aligned memory, but
# we also found some exceptions during testing, so we manually align the memory here; this part
# of the code may consume 2M * 2 * elem_size extra memory per layer.
|
|
nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
|
|
nope_allocate_shape_alignment = nope_allocate_shape + alignment
|
|
rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
|
|
rope_allocate_shape_alignment = rope_allocate_shape + alignment
|
|
|
|
nope_cache = torch.zeros(nope_allocate_shape_alignment,
|
|
dtype=dtype,
|
|
device=self.device)
|
|
rope_cache = torch.zeros(rope_allocate_shape_alignment,
|
|
dtype=dtype,
|
|
device=self.device)
|
|
nope_cache = self._align_memory(
|
|
nope_cache,
|
|
alignment)[:nope_allocate_shape].view(nope_cache_shape)
|
|
rope_cache = self._align_memory(
|
|
rope_cache,
|
|
alignment)[:rope_allocate_shape].view(rope_cache_shape)
|
|
kv_caches[layer_name] = (nope_cache, rope_cache)
|
|
|
|
bind_kv_cache(kv_caches,
|
|
self.compilation_config.static_forward_context,
|
|
self.kv_caches)
|
|
|
|
return kv_caches
|
|
|
|
def initialize_kv_cache_tensors(
|
|
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
|
"""
|
|
Initialize the memory buffer for KV cache.
|
|
|
|
Args:
|
|
kv_cache_config: The KV cache config
|
|
Returns:
|
|
Dict[str, torch.Tensor]: A map between layer names to their
|
|
corresponding memory buffer for KV cache.
|
|
"""
|
|
# init kv cache tensors
|
|
kv_cache_raw_tensors: dict[str, Union[torch.Tensor,
|
|
Optional[torch.Tensor]]] = {}
|
|
# llmdatadist needs the address of the cache tensor to be aligned to 2M
|
|
alignment = 2 * 1024 * 1024
|
|
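        # The raw buffers below are allocated as flat int8 tensors and are
        # only later viewed with the real dtype and per-layer shape, so the
        # optional 2M alignment handling stays independent of the cache dtype.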
        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
            # TODO: REFACTOR ME to sharing hybrid cache
            for idx in range(len(kv_cache_tensor.shared_by)):
                layer_name = kv_cache_tensor.shared_by[idx]
                if "linear_attn" in layer_name:
                    # for mamba linear attention
                    for layer_name_inner in kv_cache_tensor.shared_by:
                        if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \
                                layer_name_inner in kv_cache_raw_tensors.keys():
                            continue
                        if self.vllm_config.kv_transfer_config is None:
                            tensor = torch.zeros(kv_cache_tensor.size,
                                                 dtype=torch.int8,
                                                 device=self.device)
                        else:
                            cache_size_aligned = kv_cache_tensor.size + alignment
                            tensor = torch.zeros(cache_size_aligned,
                                                 dtype=torch.int8,
                                                 device=self.device)
                            tensor = self._align_memory(
                                tensor, alignment)[:kv_cache_tensor.size]
                        kv_cache_raw_tensors[layer_name_inner] = tensor
                elif "attn" in layer_name:
                    # for other attentions, e.g., self_attn, sliding window attn
                    if self.vllm_config.kv_transfer_config is None:
                        k_tensor = torch.zeros(kv_cache_tensor.size // 2,
                                               dtype=torch.int8,
                                               device=self.device)
                        v_tensor = torch.zeros(kv_cache_tensor.size // 2,
                                               dtype=torch.int8,
                                               device=self.device)
                    else:
                        cache_size = kv_cache_tensor.size // 2
                        cache_size_aligned = kv_cache_tensor.size // 2 + alignment
                        k_tensor = torch.zeros(cache_size_aligned,
                                               dtype=torch.int8,
                                               device=self.device)
                        v_tensor = torch.zeros(cache_size_aligned,
                                               dtype=torch.int8,
                                               device=self.device)
                        k_tensor = self._align_memory(k_tensor,
                                                      alignment)[:cache_size]
                        v_tensor = self._align_memory(v_tensor,
                                                      alignment)[:cache_size]
                    kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)

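        # For regular attention layers the per-layer buffer is split into two
        # equal halves, one for K and one for V, which is why
        # kv_cache_tensor.size // 2 is used above.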
        layer_names = set()
        for group in kv_cache_config.kv_cache_groups:
            for layer_name in group.layer_names:
                if layer_name in self.runner_only_attn_layers:
                    continue
                layer_names.add(layer_name)
        assert layer_names == set(kv_cache_raw_tensors.keys(
        )), "Some layers are not correctly initialized"

        kv_caches: Dict[str, torch.Tensor] = {}
        for group in self._kv_cache_spec_attn_group_iterator():
            kv_cache_spec = group.kv_cache_spec
            attn_backend = group.backend
            for layer_name in group.layer_names:
                if layer_name in self.runner_only_attn_layers:
                    continue

                # TODO: remove this after the OOM issue is located and fixed,
                # otherwise, some models may encounter OOM issues
                if isinstance(kv_cache_spec, FullAttentionSpec):
                    raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[  # type: ignore
                        layer_name]
                    assert raw_k_tensor is not None
                    assert raw_v_tensor is not None
                    assert (raw_k_tensor.numel() + raw_v_tensor.numel()
                            ) % kv_cache_spec.page_size_bytes == 0
                    num_blocks = (raw_k_tensor.numel() + raw_v_tensor.numel()
                                  ) // kv_cache_spec.page_size_bytes

                    # `num_blocks` is the number of blocks the model runner can use.
                    # `kv_cache_config.num_blocks` is the number of blocks that
                    # KVCacheManager may allocate.
                    # Since different GPUs may have a different number of layers
                    # and different memory capacities, `num_blocks` can be
                    # different on different GPUs, and `kv_cache_config.num_blocks`
                    # is set to the min of all `num_blocks`. Verify it here.
                    assert num_blocks >= kv_cache_config.num_blocks

                    if self.vllm_config.additional_config.get(
                            "kv_cache_dtype", None) == 'int8':
                        kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
                            num_blocks, kv_cache_spec.block_size,
                            kv_cache_spec.num_kv_heads,
                            kv_cache_spec.head_size)
                    elif hasattr(attn_backend, "get_supported_block_size"
                                 ) and self.use_hybrid_blocks:
                        block_size = attn_backend.get_supported_block_size()[0]

                        block_size_chunk = kv_cache_spec.block_size // block_size
                        kv_cache_shape = attn_backend.get_kv_cache_shape(
                            num_blocks * block_size_chunk, block_size,
                            kv_cache_spec.num_kv_heads,
                            kv_cache_spec.head_size)
                    else:
                        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                            num_blocks, kv_cache_spec.block_size,
                            kv_cache_spec.num_kv_heads,
                            kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype
                    k_cache = raw_k_tensor.view(dtype).view(kv_cache_shape[1:])
                    k_cache = self._convert_torch_format(k_cache)
                    v_cache = raw_v_tensor.view(dtype).view(kv_cache_shape[1:])
                    v_cache = self._convert_torch_format(v_cache)
                    kv_caches[layer_name] = (k_cache, v_cache)
                elif isinstance(kv_cache_spec, MambaSpec):
                    raw_tensor = kv_cache_raw_tensors[layer_name]
                    assert raw_tensor is not None
                    assert raw_tensor.numel(
                    ) % kv_cache_spec.page_size_bytes == 0
                    num_blocks = raw_tensor.numel(
                    ) // kv_cache_spec.page_size_bytes

                    # `num_blocks` is the number of blocks the model runner can use.
                    # `kv_cache_config.num_blocks` is the number of blocks that
                    # KVCacheManager may allocate.
                    # Since different GPUs may have a different number of layers
                    # and different memory capacities, `num_blocks` can be
                    # different on different GPUs, and `kv_cache_config.num_blocks`
                    # is set to the min of all `num_blocks`. Verify it here.
                    assert num_blocks >= kv_cache_config.num_blocks

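                    # Each Mamba "block" packs all state tensors for one
                    # request back to back inside a page of page_size_bytes.
                    # For illustration, with two states of shapes (H, D) and
                    # (H, N) the per-block layout is roughly
                    #
                    #   [ state0 (H*D elems) | state1 (H*N elems) | optional padding ]
                    #
                    # and the as_strided views below step from page to page via
                    # num_element_per_page.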
                    state_tensors = []
                    storage_offset_bytes = 0
                    for (shape, dtype) in zip(kv_cache_spec.shapes,
                                              kv_cache_spec.dtypes):
                        dtype_size = get_dtype_size(dtype)
                        num_element_per_page = (
                            kv_cache_spec.page_size_bytes // dtype_size)
                        target_shape = (num_blocks, *shape)
                        stride = torch.empty(target_shape).stride()
                        target_stride = (num_element_per_page, *stride[1:])
                        assert storage_offset_bytes % dtype_size == 0
                        tensor = torch.as_strided(
                            raw_tensor.view(dtype),
                            size=target_shape,
                            stride=target_stride,
                            storage_offset=storage_offset_bytes // dtype_size,
                        )
                        state_tensors.append(tensor)
                        storage_offset_bytes += stride[0] * dtype_size
                    kv_caches[layer_name] = state_tensors
                else:
                    raise ValueError("Unknown KV cache spec type.")

        bind_kv_cache(kv_caches,
                      self.compilation_config.static_forward_context,
                      self.kv_caches)

        return kv_caches

    def may_reinitialize_input_batch(self,
                                     kv_cache_config: KVCacheConfig) -> None:
        """
        Re-initialize the input batch if the block sizes are different from
        `[self.cache_config.block_size]`. This usually happens when there
        are multiple KV cache groups.

        Args:
            kv_cache_config: The KV cache configuration.
        """
        block_sizes = [
            kv_cache_group.kv_cache_spec.block_size
            for kv_cache_group in kv_cache_config.kv_cache_groups
        ]

        # Generate kernel_block_sizes entries that match each block_size.
        # For attention backends that support virtual block splitting,
        # use the supported block sizes from the backend.
        # For other backends (like Mamba), use [0] (no splitting).
        kernel_block_sizes = []
        for kv_cache_group_id, kv_cache_group in enumerate(
                kv_cache_config.kv_cache_groups):
            if isinstance(kv_cache_group.kv_cache_spec, AttentionSpec):
                # This is an attention backend that supports virtual
                # block splitting. Get the supported block sizes from
                # the backend.
                try:
                    attn_groups = self.attn_groups[kv_cache_group_id]
                except IndexError:
                    attn_groups = None
                if attn_groups and self.use_hybrid_blocks:
                    # Use the backend's supported block size list
                    backend = attn_groups[0].backend
                    supported_sizes = backend.get_supported_block_size()
                    # If no specific sizes are supported, fall back to the
                    # cache config block_size
                    kernel_block_size_list = (supported_sizes
                                              if supported_sizes else
                                              [self.cache_config.block_size])
                else:
                    # Fall back to the cache config block_size if no backend is found
                    kernel_block_size_list = [self.cache_config.block_size]
                kernel_block_sizes.append(kernel_block_size_list)
            else:
                # This is likely Mamba or another non-attention cache;
                # no splitting.
                # NOTE: set kernel_block_sizes to [0] to disable slot-mapping
                # computation for mamba blocks. In this case,
                # BlockTable.block_size will never equal kernel_block_sizes[0].
                kernel_block_sizes.append([0])
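        # For illustration, a hybrid model with one full-attention group
        # (block_size 128) and one mamba group would typically end up with
        # kernel_block_sizes == [[128], [0]]; the exact values depend on the
        # backend and cache config.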
        if kernel_block_sizes != [[self.cache_config.block_size]]:
            assert self.cache_config.cpu_offload_gb == 0, (
                "Cannot re-initialize the input batch when CPU weight "
                "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
                "for more details.")
            self.input_batch = InputBatch(
                max_num_reqs=self.max_num_reqs,
                max_model_len=self.model_config.max_model_len,
                max_num_batched_tokens=self.max_num_tokens,
                device=self.device,
                pin_memory=self.pin_memory,
                vocab_size=self.model_config.get_vocab_size(),
                block_sizes=block_sizes,
                is_spec_decode=bool(self.vllm_config.speculative_config),
                logitsprocs=self.input_batch.logitsprocs,
                is_pooling_model=self.is_pooling_model,
                num_speculative_tokens=(
                    self.vllm_config.speculative_config.num_speculative_tokens
                    if self.vllm_config.speculative_config else 0),
                kernel_block_sizes=kernel_block_sizes,
            )

    def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the attention backends and attention metadata builders.
        """
        assert len(self.attn_groups) == 0, \
            "Attention backends are already initialized"

        class AttentionGroupKey(NamedTuple):
            attn_backend: type[AttentionBackend]
            kv_cache_spec: KVCacheSpec

        def get_attn_backends_for_group(
            kv_cache_group_spec: KVCacheGroupSpec,
        ) -> dict[AttentionGroupKey, list[str]]:
            layers = get_layers_from_vllm_config(
                self.vllm_config, AttentionLayerBase,
                kv_cache_group_spec.layer_names)
            attn_backends = {}
            attn_backend_layers = defaultdict(list)
            # Dedupe based on full class name; this is a bit safer than
            # using the class itself as the key, because when we create
            # dynamic attention backend subclasses (e.g. ChunkedLocalAttention),
            # unless they are cached correctly there will be different objects
            # per layer.
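            # For illustration, the resulting map looks roughly like
            # (AscendAttentionBackend, full_attention_spec) ->
            # ["model.layers.0.self_attn.attn", ...]; the concrete backend
            # class and layer names depend on the model being loaded.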
            for layer_name in kv_cache_group_spec.layer_names:
                attn_backend = layers[layer_name].get_attn_backend()
                full_cls_name = attn_backend.full_cls_name()
                layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec
                if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
                    layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
                        layer_name]
                key = (full_cls_name, layer_kv_cache_spec)
                attn_backends[key] = AttentionGroupKey(attn_backend,
                                                       layer_kv_cache_spec)
                attn_backend_layers[key].append(layer_name)
            return {
                attn_backends[k]: v
                for k, v in attn_backend_layers.items()
            }

        def create_attn_groups(
            attn_backends_map: dict[AttentionBackend, list[str]],
        ) -> list[AttentionGroup]:
            attn_groups: list[AttentionGroup] = []
            for (attn_backend,
                 kv_cache_spec), layer_names in attn_backends_map.items():
                attn_metadata_builders = []
                attn_metadata_builders.append(attn_backend.get_builder_cls()(
                    kv_cache_spec,
                    layer_names,
                    self.vllm_config,
                    self.device,
                ))
                attn_group = AttentionGroup(attn_backend,
                                            attn_metadata_builders,
                                            layer_names, kv_cache_spec)
                attn_groups.append(attn_group)
            return attn_groups

        for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
            attn_backends = get_attn_backends_for_group(  # type: ignore
                kv_cache_group_spec)
            self.attn_groups.append(create_attn_groups(attn_backends))

        # Calculate reorder batch threshold (if needed)
        self.calculate_reorder_batch_threshold()

    def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
        return itertools.chain.from_iterable(self.attn_groups)

    def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]:
        if not self.kv_cache_config.kv_cache_groups:
            return
        for attn_groups in self.attn_groups:
            yield from attn_groups

    def calculate_reorder_batch_threshold(self) -> None:
        """
        Check that, if any backends reorder batches, the reordering
        is compatible (e.g., the decode threshold is the same).
        """
        for group in self._attn_group_iterator():
            attn_metadata_builder_i = group.get_metadata_builder()
            if hasattr(attn_metadata_builder_i, "reorder_batch_threshold"):
                # Check that, if any backends reorder batches, the reordering
                # is compatible (e.g., the decode threshold is the same).
                reorder_batch_threshold_i = (
                    attn_metadata_builder_i.reorder_batch_threshold)
                if reorder_batch_threshold_i is not None:
                    if self.reorder_batch_threshold is not None:
                        if reorder_batch_threshold_i != \
                                self.reorder_batch_threshold:
                            raise ValueError(
                                f"Attention backend reorders decodes with "
                                f"threshold {reorder_batch_threshold_i} but other "
                                f"backend uses threshold "
                                f"{self.reorder_batch_threshold}")
                    else:
                        self.reorder_batch_threshold = reorder_batch_threshold_i

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
        Attention module in the static forward context.
        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.
        """

        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
        use_sfa = self.ascend_config.use_sfa
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
        for layer_name, attn_module in attn_layers.items():
            if (kv_tgt_layer :=
                    attn_module.kv_sharing_target_layer_name) is not None:
                # The layer doesn't need its own KV cache and will use that of
                # the target layer. We skip creating a KVCacheSpec for it, so
                # that KV cache management logic will act as if this layer does
                # not exist, and doesn't allocate KV cache for the layer. This
                # enables the memory saving of cross-layer kv sharing, allowing
                # a given amount of memory to accommodate longer context lengths
                # or enable more requests to be processed simultaneously.
                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                continue
            if isinstance(attn_module, AscendMultiHeadLatentAttention):
                continue

            # TODO: Support other attention modules, e.g., cross-attention
            # TODO(lucas): move the attention specs into the model layers like
            # the attention backends
            if attn_module.attn_type == AttentionType.DECODER:
                if use_mla and not use_sfa:
                    kv_cache_spec[layer_name] = MLAAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype,
                        cache_dtype_str=self.cache_config.cache_dtype)
                else:
                    # TODO(cmq): This is a hacky way to fix the deepseek kv
                    # cache when using DSA. Fixing the spec in vLLM is the
                    # final solution.
                    kv_cache_spec[layer_name] = FullAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype)
            elif attn_module.attn_type in (AttentionType.ENCODER,
                                           AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
        if len(mamba_layers) > 0:
            if (self.vllm_config.speculative_config is not None
                    and self.vllm_config.model_config.hf_config.model_type
                    not in ["qwen3_next"]):
                raise NotImplementedError(
                    "Mamba with speculative decoding is not supported yet.")
            if self.vllm_config.cache_config.enable_prefix_caching:
                raise NotImplementedError(
                    "Prefix caching is not supported for Mamba yet.")
            max_model_len = self.vllm_config.model_config.max_model_len

            page_size_padded = (
                self.vllm_config.cache_config.mamba_page_size_padded)

            # Set block_size to max_model_len, so that the mamba model will
            # always have only one block in the KV cache.
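            # Mamba layers keep a fixed-size recurrent state per request
            # rather than a per-token cache, so a single block per request
            # is sufficient.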
            for layer_name, mamba_module in mamba_layers.items():
                kv_cache_spec[layer_name] = MambaSpec(
                    shapes=mamba_module.get_state_shape(),
                    dtypes=mamba_module.get_state_dtype(),
                    block_size=max_model_len,
                    page_size_padded=page_size_padded,
                    mamba_type=mamba_module.mamba_type,
                    num_speculative_blocks=(
                        self.speculative_config.num_speculative_tokens
                        if self.speculative_config else 0),
                )

        return kv_cache_spec

    def initialize_aclgraph_capture(self) -> None:
        min_ag_support = AttentionCGSupport.ALWAYS
        min_ag_builder_name = None

        for attn_group in self._attn_group_iterator():
            builder = attn_group.get_metadata_builder()
            if builder.aclgraph_support.value < min_ag_support.value:
                min_ag_support = builder.aclgraph_support
                min_ag_builder_name = builder.__class__.__name__

        # This is an imitation of compilation_config.splitting_ops_contain_attention()
        splitting_ops_contain_attention = (
            self.compilation_config.splitting_ops is not None
            and all(op in self.compilation_config.splitting_ops for op in [
                "vllm.unified_ascend_attention_with_output",
                "vllm.mla_forward",
            ]))

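        # The checks below may downgrade the requested aclgraph mode: a FULL
        # mixed-batch mode on a backend without full-graph support becomes
        # FULL_AND_PIECEWISE (if attention ops are split out) or
        # FULL_DECODE_ONLY, and spec-decode on a backend without uniform-batch
        # support falls back to PIECEWISE or NONE.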
        # Flexibly resolve the aclgraph mode
        aclgraph_mode = self.compilation_config.cudagraph_mode
        # check that full graphs for mixed batches are supported
        if aclgraph_mode.mixed_mode() == CUDAGraphMode.FULL \
                and min_ag_support != AttentionCGSupport.ALWAYS:
            msg = (f"ACLGraphMode.{aclgraph_mode.name} is not supported "
                   f"with {min_ag_builder_name} backend (support: "
                   f"{min_ag_support})")
            if min_ag_support == AttentionCGSupport.NEVER:
                # if no full graphs are supported at all, just raise it.
                msg += "; please try cudagraph_mode=PIECEWISE, and "\
                    "make sure compilation level is piecewise"
                raise ValueError(msg)

            # attempt to resolve the full graph related mode
            if splitting_ops_contain_attention:
                msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
                aclgraph_mode = self.compilation_config.cudagraph_mode = (
                    CUDAGraphMode.FULL_AND_PIECEWISE)
            else:
                msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
                aclgraph_mode = self.compilation_config.cudagraph_mode = (
                    CUDAGraphMode.FULL_DECODE_ONLY)
            logger.warning(msg)

        # check whether spec-decode + decode full-graphs is supported
        if (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL
                and self.uniform_decode_query_len > 1 and min_ag_support.value
                < AttentionCGSupport.UNIFORM_BATCH.value):
            msg = (f"CUDAGraphMode.{aclgraph_mode.name} is not supported"
                   f" with spec-decode for attention backend "
                   f"{min_ag_builder_name} (support: {min_ag_support})")
            if splitting_ops_contain_attention:
                msg += "; setting cudagraph_mode=PIECEWISE"
                aclgraph_mode = self.compilation_config.cudagraph_mode = \
                    CUDAGraphMode.PIECEWISE
            else:
                msg += "; setting cudagraph_mode=NONE"
                aclgraph_mode = self.compilation_config.cudagraph_mode = \
                    CUDAGraphMode.NONE
            logger.warning(msg)

        # double check that we can support full graphs if they are requested,
        # even after automatic downgrades
        if aclgraph_mode.has_full_cudagraphs() \
                and min_ag_support == AttentionCGSupport.NEVER:
            raise ValueError(f"CUDAGraphMode.{aclgraph_mode.name} is not "
                             f"supported with {min_ag_builder_name} backend ("
                             f"support:{min_ag_support}) "
                             "; please try cudagraph_mode=PIECEWISE, "
                             "and make sure compilation level is piecewise")

        self.aclgraph_dispatcher.initialize_cudagraph_keys(
            self.compilation_config.cudagraph_mode,
            self.uniform_decode_query_len)

    def _capture_aclgraphs(self, compilation_cases: list[int],
                           aclgraph_runtime_mode: CUDAGraphMode,
                           uniform_decode: bool):
        assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
            aclgraph_runtime_mode in [CUDAGraphMode.FULL,
                                      CUDAGraphMode.PIECEWISE]

        # Only rank 0 should print the progress bar during capture
        if is_global_first_rank():
            logger.info(
                "Starting to capture ACL graphs for cases: %s, "
                "mode: %s, uniform_decode: %s", compilation_cases,
                aclgraph_runtime_mode.name, uniform_decode)
            compilation_cases = tqdm(
                compilation_cases,
                disable=not self.load_config.use_tqdm_on_load,
                desc="Capturing ACL graphs ({}, {})".format(
                    "decode" if uniform_decode else "mixed prefill-decode",
                    aclgraph_runtime_mode.name))
        # We skip EPLB here since we don't want to record dummy metrics
        for num_tokens in compilation_cases:
            for _ in range(self.compilation_config.cudagraph_num_of_warmups):
                # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
                # But be careful: warming up with `NONE` is orthogonal to
                # whether we want to warm up attention or not. This is
                # different from the case where `FULL` implies capturing
                # attention while `PIECEWISE` implies no attention.
                force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL)
                self._dummy_run(num_tokens,
                                aclgraph_runtime_mode=CUDAGraphMode.NONE,
                                force_attention=force_attention,
                                uniform_decode=uniform_decode)
            self._dummy_run(num_tokens,
                            aclgraph_runtime_mode=aclgraph_runtime_mode,
                            force_attention=force_attention,
                            uniform_decode=uniform_decode)

    def _capture_model(self):
        if not self.use_aclgraph:
            logger.warning(
                "Skipping ACL graph capture. To turn on ACL graph capture, "
                "ensure `aclgraph_mode` was not manually set to `NONE`")
            return
        else:
            self.initialize_aclgraph_capture()

        set_cudagraph_capturing_enabled(True)
        # Trigger ACL graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with graph_capture(device=self.device):
            aclgraph_mode = self.compilation_config.cudagraph_mode
            if aclgraph_mode.mixed_mode() != CUDAGraphMode.NONE:
                aclgraph_runtime_mode = aclgraph_mode.mixed_mode()

                compilation_cases = list(reversed(self.aclgraph_batch_sizes))

                try:
                    self._capture_aclgraphs(
                        compilation_cases,
                        aclgraph_runtime_mode=aclgraph_runtime_mode,
                        uniform_decode=False)
                except Exception as e:
                    error_msg = str(e)
                    error_code = '0x7020023'
                    pattern = r'retCode=([^,\s\.]+)'
                    match = re.search(pattern, error_msg)
                    if match:
                        retCode = match.group(1)
                        # Determine whether the error is caused by a stream
                        # capture failure.
                        if match and retCode == error_code:
                            logger.error(
                                f"ACL graph size capture failed: {type(e).__name__}:\n"
                                "ACLgraph has insufficient available streams to capture the configured number of sizes. "
                                "Please verify both the availability of adequate streams and the appropriateness of the configured size count.\n\n"
                                "Recommended solutions:\n"
                                "1. Manually configure the compilation_config parameter "
                                "with a reduced set of sizes: '{\"cudagraph_capture_sizes\":[size1, size2, size3, ...]}'.\n"
                                "2. Utilize ACLgraph's full graph mode as an alternative to the piece-wise approach.\n\n"
                                f"{str(e)}")
                    raise

            if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
                    aclgraph_mode.separate_routine():
                max_num_tokens = self.scheduler_config.max_num_seqs * \
                    self.uniform_decode_query_len
                decode_cudagraph_batch_sizes = [
                    x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
                    and x >= self.uniform_decode_query_len
                ]
                compilation_cases_decode = list(
                    reversed(decode_cudagraph_batch_sizes))
                self._capture_aclgraphs(
                    compilation_cases=compilation_cases_decode,
                    aclgraph_runtime_mode=CUDAGraphMode.FULL,
                    uniform_decode=True)

        # Disable aclgraph capturing globally, so any unexpected aclgraph
        # capturing will be detected and raise an error after here.
        # Note: We don't put it into the graph_capture context manager because
        # we may do lazy capturing in the future, which still allows capturing
        # after here.
        set_cudagraph_capturing_enabled(False)

    def capture_model(self) -> None:

        compilation_counter.num_gpu_runner_capture_triggers += 1

        start_time = time.perf_counter()
        start_free_npu_memory = torch.npu.mem_get_info()[0]

        self._capture_model()

        end_time = time.perf_counter()
        end_free_npu_memory = torch.npu.mem_get_info()[0]
        elapsed_time = end_time - start_time
        npu_graph_size = start_free_npu_memory - end_free_npu_memory
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, npu_graph_size / (1 << 30))

    def _get_prompt_logprobs_dict(
        self,
        hidden_states: torch.Tensor,
        scheduler_output: "SchedulerOutput",
    ) -> dict[str, Optional[LogprobsTensors]]:
        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
        if not num_prompt_logprobs_dict:
            return {}

        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

        # Since prompt logprobs are a rare feature, prioritize a simple,
        # maintainable loop over optimal performance.
        completed_prefill_reqs = []
        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():

            num_tokens = scheduler_output.num_scheduled_tokens[req_id]

            # Get metadata for this request.
            request = self.requests[req_id]
            num_prompt_tokens = len(request.prompt_token_ids)
            prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
                self.device, non_blocking=True)

            # Set up target LogprobsTensors object.
            logprobs_tensors = in_progress_dict.get(req_id)
            if not logprobs_tensors:
                # Create empty logprobs CPU tensors for the entire prompt.
                # If chunked, we'll copy them in slice by slice.
                logprobs_tensors = LogprobsTensors.empty_cpu(
                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
                in_progress_dict[req_id] = logprobs_tensors

            # Determine the number of logits to retrieve.
            start_idx = request.num_computed_tokens
            start_tok = start_idx + 1
            num_remaining_tokens = num_prompt_tokens - start_tok
            if num_tokens <= num_remaining_tokens:
                # This is a chunk; more tokens remain.
                # In the == case, there are no more prompt logprobs to produce,
                # but we want to defer returning them to the next step, where
                # we have new generated tokens to return.
                num_logits = num_tokens
            else:
                # This is the last chunk of prompt tokens to return.
                num_logits = num_remaining_tokens
                completed_prefill_reqs.append(req_id)
                prompt_logprobs_dict[req_id] = logprobs_tensors

            if num_logits <= 0:
                # This can happen for the final chunk if we prefilled exactly
                # (num_prompt_tokens - 1) tokens for this request in the prior
                # step. There are no more prompt logprobs to produce.
                continue

            # Get the logits corresponding to this req's prompt tokens.
            # If this is a partial request (i.e. chunked prefill),
            # then a prompt logprob is generated for each index.
            req_idx = self.input_batch.req_id_to_index[req_id]
            offset = self.query_start_loc_np[req_idx].item()
            prompt_hidden_states = hidden_states[offset:offset + num_logits]
            logits = self.model.compute_logits(prompt_hidden_states)

            # Get the "target" tokens for each index. For the prompt token at
            # index i, the token at prompt index i+1 is the "sampled" token we
            # want to gather the logprob for.
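            # For illustration, with prompt [t0, t1, t2, t3] and start_tok=1,
            # the logits computed at prompt positions 0..2 are scored against
            # the targets [t1, t2, t3].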
            tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]

            # Compute prompt logprobs.
            logprobs = self.sampler.compute_logprobs(logits)
            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
                logprobs, num_prompt_logprobs, tgt_token_ids)

            # Transfer NPU->CPU async.
            chunk_slice = slice(start_idx, start_idx + num_logits)
            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
                token_ids, non_blocking=True)
            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
                                                         non_blocking=True)
            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
                ranks, non_blocking=True)

        # Remove requests that have completed prefill from the batch
        # num_prompt_logprobs_dict.
        for req_id in completed_prefill_reqs:
            del num_prompt_logprobs_dict[req_id]
            del in_progress_dict[req_id]

        # Must synchronize the non-blocking NPU->CPU transfers.
        if prompt_logprobs_dict:
            torch.npu.synchronize()

        return prompt_logprobs_dict

    def get_supported_pooling_tasks(self):
        model = self.get_model()
        if not is_pooling_model(model):
            return []

        return list(model.pooler.get_supported_tasks())

    def _build_drafter_prepare_inputs_torchair_param(self):
        return False