# # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. # Copyright 2023 The vLLM team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # This file is a part of the vllm-ascend project. # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py # import copy import gc import itertools import time from collections import defaultdict from collections.abc import Iterator from contextlib import contextmanager, nullcontext from copy import deepcopy from dataclasses import dataclass from multiprocessing import Manager from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union, cast) import numpy as np import numpy.typing as npt import torch import torch._dynamo.cache_size import torch.distributed as dist import torch.nn as nn from tqdm import tqdm # type: ignore from vllm.attention import AttentionType, get_attn_backend from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import set_cudagraph_capturing_enabled from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config) from vllm.distributed import tensor_model_parallel_all_gather from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import (get_dp_group, get_pp_group, get_tp_group, is_global_first_rank) from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model from vllm.model_executor.models.interfaces import supports_transcription from vllm.model_executor.models.interfaces_base import ( VllmModelForPooling, is_pooling_model, is_text_generation_model) from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, LazyLoader, cdiv, get_dtype_size, is_pin_memory_available) from vllm.utils.jsontree import json_map_leaves from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher # yapf conflicts with isort for this block # yapf: disable from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, MambaSpec, UniformTypeKVCacheSpecs) # yapf: enable from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, LogprobsTensors, ModelRunnerOutput, PoolerOutput) from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache, gather_mm_placeholders, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import (MoECommType, set_ascend_forward_context) from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, set_graph_params, update_attn_params) from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor from vllm_ascend.eplb.core.eplb_device_transfer_loader import \ D2DExpertWeightLoader from vllm_ascend.eplb.core.eplb_worker import EplbProcess from vllm_ascend.eplb.eplb_updator import EplbUpdator from vllm_ascend.eplb.utils import model_register from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.platform import NPUPlatform from vllm_ascend.sample.logits_processor import build_logitsprocs from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler from vllm_ascend.spec_decode import get_spec_decode_method from vllm_ascend.spec_decode.eagle_proposer import EagleProposer from vllm_ascend.spec_decode.interface import SpecDcodeType from vllm_ascend.spec_decode.mtp_proposer import MtpProposer from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, ProfileExecuteDuration, get_ascend_soc_version, is_310p, lmhead_tp_enable) from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] from vllm.v1.core.sched.output import SchedulerOutput else: xgr = LazyLoader("xgr", globals(), "xgrammar") import torch_npu import vllm_ascend.envs as envs_ascend # if true, allow tensor initialization and casting with internal format (e.g., NZ) torch.npu.config.allow_internal_format = True if is_310p(): torch_npu.npu.set_compile_mode(jit_compile=False) ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ else: ACL_FORMAT = ACL_FORMAT_FRACTAL_ND @dataclass class GraphCaptureContext: stream: torch.npu.Stream @contextmanager def graph_capture(device: torch.device): """ `graph_capture` is a context manager which should surround the code that is capturing the NPU graph. Its main purpose is to ensure that the some operations will be run after the graph is captured, before the graph is replayed. It returns a `GraphCaptureContext` object which contains the necessary data for the graph capture. Currently, it only contains the stream that the graph capture is running on. This stream is set to the current NPU stream when the context manager is entered and reset to the default stream when the context manager is exited. This is to ensure that the graph capture is running on a separate stream from the default stream, in order to explicitly distinguish the kernels to capture from other kernels possibly launched on background in the default stream. """ graph_capture_context = GraphCaptureContext( torch.npu.Stream(device=device)) stream = graph_capture_context.stream # we use nullcontext now maybe_ca_context = nullcontext() # ensure all initialization operations complete before attempting to # capture the graph on another stream curr_stream = torch.npu.current_stream() if curr_stream != stream: stream.wait_stream(curr_stream) with torch.npu.stream(stream), maybe_ca_context: yield graph_capture_context # Wrapper for ModelRunnerOutput to support overlapped execution. class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput): def __init__( self, model_runner_output: ModelRunnerOutput, sampled_token_ids: torch.Tensor, invalid_req_indices: list[int], async_output_copy_stream: torch.npu.Stream, ): self._model_runner_output = model_runner_output self._invalid_req_indices = invalid_req_indices # Event on the copy stream so we can synchronize the non-blocking copy. self._async_copy_ready_event = torch.npu.Event() # Keep a reference to the device tensor to avoid it being # deallocated until we finish copying it to the host. self._sampled_token_ids = sampled_token_ids # Initiate the copy on a separate stream, but do not synchronize it. default_stream = torch.npu.current_stream() with torch.npu.stream(async_output_copy_stream): async_output_copy_stream.wait_stream(default_stream) self._sampled_token_ids_cpu = self._sampled_token_ids.to( 'cpu', non_blocking=True) self._async_copy_ready_event.record() def get_output(self) -> ModelRunnerOutput: """Copy the device tensors to the host and return a ModelRunnerOutput. This function blocks until the copy is finished. """ self._async_copy_ready_event.synchronize() # Release the device tensor once the copy has completed del self._sampled_token_ids valid_sampled_token_ids = self._sampled_token_ids_cpu.tolist() for i in self._invalid_req_indices: valid_sampled_token_ids[i].clear() output = self._model_runner_output output.sampled_token_ids = valid_sampled_token_ids return output class NPUModelRunner(LoRAModelRunnerMixin): def __init__(self, vllm_config: VllmConfig, device: torch.device): self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config self.compilation_config = vllm_config.compilation_config self.load_config = vllm_config.load_config self.lora_config = vllm_config.lora_config self.parallel_config = vllm_config.parallel_config self.pin_memory = is_pin_memory_available() self.scheduler_config = vllm_config.scheduler_config self.speculative_config = vllm_config.speculative_config self.block_size = vllm_config.cache_config.block_size self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len, self.block_size) self.max_num_tokens = self.scheduler_config.max_num_batched_tokens decode_max_num_seqs = getattr(self.scheduler_config, 'decode_max_num_seqs', 0) self.max_num_reqs = max(self.scheduler_config.max_num_seqs, decode_max_num_seqs) self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.device = device if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP: self.prefetch_stream = torch.npu.Stream(device=device) else: self.prefetch_stream = None self.dtype = self.model_config.dtype if envs_ascend.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION: # TODO: drop the env config to use ascend sampler by default from vllm_ascend.sample.sampler import AscendSampler self.sampler = AscendSampler() else: from vllm.v1.sample.sampler import Sampler self.sampler = Sampler() self.reorder_batch_threshold: Optional[int] = None # Lazy initialization, these will be set after __init__ self.kv_caches: List[torch.Tensor] = [] self.attn_groups: list[list[AttentionGroup]] = [] self.encoder_cache: Dict[str, torch.Tensor] = {} self.attn_mask = None self.attn_state = None self.requests: Dict[str, CachedRequestState] = {} self.intermediate_tensors: Optional[IntermediateTensors] = None self.runner_only_attn_layers: set[str] = set() self.ascend_config = get_ascend_config() if self.ascend_config.ascend_scheduler_config.enabled: self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled else: self.chunked_prefill_enabled = True if self.cache_config.cache_dtype == "auto": self.kv_cache_dtype = self.dtype else: self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ self.cache_config.cache_dtype] # use_hybrid_blocks: if hybrid blocks is used. self.use_hybrid_blocks: bool = False self.need_accepted_tokens: bool = False self.is_multimodal_model = self.model_config.is_multimodal_model self.is_pooling_model = self.model_config.pooler_config is not None if self.is_multimodal_model: self.inputs_embeds = torch.zeros( (self.max_num_tokens, self.model_config.get_hidden_size()), dtype=self.dtype, device=self.device) # Set up Attention self.attn_backend = get_attn_backend( 0, self.dtype, None, self.block_size, use_mla=self.model_config.use_mla, use_sfa=self.ascend_config.use_sfa) if torch.version.cann.startswith("8.3"): self.attn_mask_builder = AttentionMaskBuilder( self.scheduler_config.max_num_batched_tokens, self.dtype, self.device) else: self.attn_mask_builder = AttentionMaskBuilder( self.model_config.max_model_len, self.dtype) # Set up speculative decoding. self.spec_attn_mask = None self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer]] = None self.actual_seq_lengths_q: list[int] = [] self.decode_token_per_req = 1 if self.speculative_config: spec_token_num = self.speculative_config.num_speculative_tokens assert spec_token_num > 0 self.decode_token_per_req = 1 + spec_token_num self.spec_attn_mask = torch.triu(torch.ones(2048, 2048, dtype=torch.bool), diagonal=1).to(self.device) if get_pp_group().is_last_rank: self.drafter = get_spec_decode_method( self.speculative_config.method, self.vllm_config, self.device, self) self.rejection_sampler = AscendRejectionSampler() # Persistent batch. self.input_ids = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=self.device) self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) self.query_start_loc = torch.zeros(self.max_num_reqs + 1, dtype=torch.int32, device=self.device) self.seq_lens = torch.zeros(self.max_num_reqs, dtype=torch.int32, device=self.device) self.slot_mapping = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=self.device) self.uses_mrope = self.model_config.uses_mrope # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: # NOTE: `mrope_positions` is implemented with one additional dummy # position on purpose to make it non-contiguous so that it can work # with torch compile. # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923 # NOTE: When M-RoPE is enabled, position ids are 3D regardless of # the modality of inputs. For text-only inputs, each dimension has # identical position IDs, making M-RoPE functionally equivalent to # 1D-RoPE. # See page 5 of https://arxiv.org/abs/2409.12191 self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1), dtype=torch.int64, device=self.device) self.mrope_positions_cpu = torch.zeros( (3, self.max_num_tokens + 1), dtype=torch.int64, device="cpu", pin_memory=True) self.mrope_positions_np = self.mrope_positions_cpu.numpy() # OPTIMIZATION: Cache the tensors rather than creating them every step. self.arange_np: npt.NDArray[np.int32] = np.arange(max( self.max_num_reqs + 1, self.model_config.max_model_len, self.max_num_tokens), dtype=np.int32) # NOTE(woosuk): These tensors are "stateless", i.e., they are literally # a faster version of creating a new tensor every time. Thus, we should # not make any assumptions about the values in these tensors. self.input_ids_cpu = torch.zeros(self.max_num_tokens, dtype=torch.int32, device="cpu", pin_memory=True) self.positions_cpu = torch.zeros(self.max_num_tokens, dtype=torch.int64, device="cpu", pin_memory=True) self.positions_np = self.positions_cpu.numpy() self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, dtype=torch.int32, device="cpu", pin_memory=True) self.slot_mapping_np = self.slot_mapping_cpu.numpy() self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=True) self.query_start_loc_np = self.query_start_loc_cpu.numpy() self.seq_lens_cpu = torch.zeros(self.max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=True) self.seq_lens_np = self.seq_lens_cpu.numpy() self.use_aclgraph = self._use_aclgraph() self.aclgraph_batch_sizes = list( reversed(self.compilation_config.cudagraph_capture_sizes)) self.uniform_decode_query_len = 1 if not self.speculative_config else \ 1 + self.speculative_config.num_speculative_tokens # aclgraph dispatcher for runtime aclgraph dispatching. self.aclgraph_dispatcher = CudagraphDispatcher(self.vllm_config) # Cached outputs. self._draft_token_ids: Optional[Union[list[list[int]], torch.Tensor]] = None # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True self.in_profile_run = False # kv role self.is_kv_producer = False self.is_kv_consumer = False if vllm_config.kv_transfer_config is not None: self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer # NOTE: Technically, MC2 can have 512 tokens each rank, but this will consume too much memory. The formula is: # ((maxBs * tokenNeedSizeDispatch * ep_worldsize * localMoeExpertNum) + (maxBs * tokenNeedSizeCombine * (k + sharedExpertNum))) * 2 # so we have to limit the MC2 tokens to save memory, should fix this in the future. self.mc2_tokens_capacity = 512 self.reserved_mc2_mask = torch.zeros( self.mc2_tokens_capacity, dtype=torch.bool, device=self.device, ) self.dynamic_eplb = self.ascend_config.dynamic_eplb if self.dynamic_eplb: self.is_eplb_warmuped = False self.eplb_loader = D2DExpertWeightLoader() self.manager = Manager() self.shared_dict = self.manager.dict({ "expert_map": None, "moe_load": None, "expert_maps": None }) self.eplb_process = EplbProcess(shared_dict=self.shared_dict, policy_type=1, enable_d2d=True) self.process = self.eplb_process._launch_process() ascend_config = get_ascend_config() self.eplb_updator = EplbUpdator(ascend_config, self.eplb_loader, self.eplb_process, self.process) self.use_async_scheduling = self.scheduler_config.async_scheduling self.async_output_copy_stream = torch.npu.Stream() if \ self.use_async_scheduling else None # Input Batch # NOTE(Chen): Ideally, we should initialize the input batch inside # `initialize_kv_cache` based on the kv cache config. However, as in # https://github.com/vllm-project/vllm/pull/18298, due to some unknown # reasons, we have to initialize the input batch before `load_model`, # quantization + weight offloading will fail otherwise. As a temporary # solution, we initialize the input batch here, and re-initialize it # in `initialize_kv_cache` if the block_sizes here is different from # the block_sizes in the kv cache config. self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=self.model_config.max_model_len, max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.block_size], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( self.vllm_config, self.device, self.pin_memory, self.is_pooling_model, self.vllm_config.model_config.logits_processors), is_pooling_model=self.is_pooling_model, kernel_block_sizes=[[self.vllm_config.cache_config.block_size]], ) self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, dtype=torch.int64) self.num_draft_tokens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) def _make_buffer(self, *size: Union[int, torch.SymInt], dtype: torch.dtype, numpy: bool = True) -> CpuGpuBuffer: # Bfloat16 torch tensors cannot be directly cast to a numpy array, so # if a bfloat16 buffer is needed without a corresponding numpy array, # don't bother instantiating the numpy array. return CpuGpuBuffer(*size, dtype=dtype, device=self.device, pin_memory=self.pin_memory, with_numpy=numpy) def _update_states_after_model_execute( self, output_token_ids: torch.Tensor) -> None: """Update the cached states after model execution. This is used for MTP/EAGLE for hybrid models, as in linear attention, only the last token's state is kept. In MTP/EAGLE, for draft tokens the state are kept util we decide how many tokens are accepted for each sequence, and a shifting is done during the next iteration based on the number of accepted tokens. """ if not self.model_config.is_hybrid or not self.speculative_config: return # Find the number of accepted tokens for each sequence. num_accepted_tokens = (torch.cat( [ output_token_ids, torch.full((output_token_ids.size(0), 1), -1, device=output_token_ids.device), ], dim=1) == -1).int().argmax(-1).cpu().numpy() for i, num_tokens in enumerate(num_accepted_tokens): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens def _use_aclgraph(self) -> bool: return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. This happens when a request is aborted and # then resubmitted with the same ID. In this case, we treat them as two # distinct requests - clearing the cached states for the first request # and handling the second as a new request. for req_id in scheduler_output.finished_req_ids: self.input_batch.remove_request(req_id) for mm_hash in scheduler_output.free_encoder_mm_hashes: self.encoder_cache.pop(mm_hash, None) # Remove the unscheduled requests from the persistent batch. # NOTE(woosuk): The unscheduled requests are either preempted requests # or running requests that are not scheduled in this step. We remove # them from the persistent batch but keep their cached states since # they will be scheduled again sometime in the future. scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() cached_req_ids = self.input_batch.req_id_to_index.keys() unscheduled_req_ids = cached_req_ids - scheduled_req_ids # NOTE(woosuk): The persistent batch optimization assumes that # consecutive batches contain mostly the same requests. If batches # have low request overlap (e.g., alternating between two distinct # sets of requests), this optimization becomes very inefficient. for req_id in unscheduled_req_ids: self.input_batch.remove_request(req_id) req_ids_to_add: list[str] = [] # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: req_id = new_req_data.req_id sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params if sampling_params and \ sampling_params.sampling_type == SamplingType.RANDOM_SEED: generator = torch.Generator(device=self.device) generator.manual_seed(sampling_params.seed) else: generator = None if pooling_params: assert (task := pooling_params.task) is not None, ( "You did not set `task` in the API") model = cast(VllmModelForPooling, self.get_model()) to_update = model.pooler.get_pooling_updates(task) to_update.apply(pooling_params) backward_kwargs = {} backward_kwargs["mm_features"] = new_req_data.mm_features self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, sampling_params=sampling_params, pooling_params=pooling_params, generator=generator, block_ids=new_req_data.block_ids, num_computed_tokens=new_req_data.num_computed_tokens, output_token_ids=[], lora_request=new_req_data.lora_request, **backward_kwargs, ) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: self._init_mrope_positions(self.requests[req_id]) req_ids_to_add.append(req_id) # Update the states of the running/resumed requests. is_last_rank = get_pp_group().is_last_rank req_data = scheduler_output.scheduled_cached_reqs for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] num_computed_tokens = req_data.num_computed_tokens[i] new_block_ids = req_data.new_block_ids[i] resumed_from_preemption = req_data.resumed_from_preemption[i] # Update the cached states. req_state.num_computed_tokens = num_computed_tokens if not is_last_rank: # When using PP, the scheduler sends the sampled tokens back, # because there's no direct communication between the first- # stage worker and the last-stage worker. new_token_ids = req_data.new_token_ids[i] # Add the sampled token(s) from the previous step (if any). # This doesn't include "unverified" tokens like spec tokens. num_new_tokens = (num_computed_tokens + len(new_token_ids) - req_state.num_tokens) if num_new_tokens == 1: # Avoid slicing list in most common case. req_state.output_token_ids.append(new_token_ids[-1]) elif num_new_tokens > 0: req_state.output_token_ids.extend( new_token_ids[-num_new_tokens:]) # Update the block IDs. if not resumed_from_preemption: if new_block_ids is not None: # Append the new blocks to the existing block IDs. for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): block_ids.extend(new_ids) else: assert new_block_ids is not None # The request is resumed from preemption. # Replace the existing block IDs with the new ones. req_state.block_ids = new_block_ids req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: # The request is not in the persistent batch. # The request was either preempted and resumed later, or was not # scheduled in the previous step and needs to be added again. req_ids_to_add.append(req_id) continue # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) if new_block_ids is not None: self.input_batch.block_table.append_row( new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu # because the sampled tokens are already cached. if not is_last_rank: # Add new_token_ids to token_ids_cpu. start_token_index = num_computed_tokens end_token_index = num_computed_tokens + len(new_token_ids) self.input_batch.token_ids_cpu[ req_index, start_token_index:end_token_index] = new_token_ids self.input_batch.num_tokens_no_spec[ req_index] = end_token_index self.input_batch.num_tokens[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. spec_token_ids = ( scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) if spec_token_ids: num_spec_tokens = len(spec_token_ids) start_index = self.input_batch.num_tokens_no_spec[req_index] end_token_index = start_index + num_spec_tokens self.input_batch.token_ids_cpu[ req_index, start_index:end_token_index] = spec_token_ids # NOTE(woosuk): `num_tokens` here may include spec tokens. self.input_batch.num_tokens[req_index] += num_spec_tokens # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. for req_id in req_ids_to_add: req_state = self.requests[req_id] self.input_batch.add_request(req_state) # Condense the batched states if there are gaps left by removed requests self.input_batch.condense() # Allow attention backend to reorder the batch, potentially self._may_reorder_batch(scheduler_output) # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() def _init_mrope_positions(self, req_state: CachedRequestState): image_grid_thw = [] video_grid_thw = [] second_per_grid_ts = [] audio_feature_lengths = [] use_audio_in_video = False assert req_state.mm_features is not None for mm_feature in req_state.mm_features: mm_item = mm_feature.data if mm_item is None: continue mm_input = mm_item.get_data() if (t := mm_input.get("image_grid_thw")) is not None: image_grid_thw.append(t.tolist()) if (t := mm_input.get("video_grid_thw")) is not None: video_grid_thw.append(t.tolist()) if (t := mm_input.get("second_per_grid_ts")) is not None: second_per_grid_ts.append(t) if (t := mm_input.get("audio_feature_lengths")) is not None: audio_feature_lengths.append(t) if mm_input.get("use_audio_in_video") is True: use_audio_in_video = True req_state.mrope_positions, req_state.mrope_position_delta = \ MRotaryEmbedding.get_input_positions_tensor( req_state.prompt_token_ids, hf_config=self.model_config.hf_config, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, second_per_grid_ts=second_per_grid_ts, audio_feature_lengths=audio_feature_lengths, use_audio_in_video=use_audio_in_video, ) def _sync_metadata_across_dp( self, num_tokens: int, with_prefill: bool, enable_dbo: bool ) -> tuple[int, Optional[torch.Tensor], bool, bool]: # TODO: In vLLM, the only thing that needs to be synced is num_tokens, but in # our case, we still need to sync the other two flags as well. So we need to # include them in the all_reduce operation, and more over, we CANNOT skip it # even if we are running in eager mode, which harms performance. # FIXME: Restore the `or self.vllm_config.model_config.enforce_eager` here # immediately once the other two flags are no longer needed. if self.dp_size == 1: return num_tokens, None, with_prefill, enable_dbo # Sync num_tokens, with_prefill, enable_dbo across dp ranks num_tokens_tensor = torch.tensor([ num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size) ], dtype=torch.int32, device="npu") flags_tensor = torch.tensor( [int(with_prefill), int(not enable_dbo)], dtype=torch.int32, device="npu") packed_tensor = torch.cat([num_tokens_tensor, flags_tensor]) dist.all_reduce(packed_tensor, group=get_dp_group().device_group) # Unpack the results num_tokens_across_dp = packed_tensor[:-2] synced_flags = packed_tensor[-2:] max_tokens_across_dp = torch.max(num_tokens_across_dp).item() global_with_prefill = bool(synced_flags[0]) global_enable_dbo = not bool(synced_flags[1]) # Create a tensor for num_tokens_after_padding num_tokens_after_padding = torch.tensor([max_tokens_across_dp] * self.dp_size, device="npu", dtype=torch.int32) return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo def _check_dbo_is_valid(self, query_lens: torch.Tensor, attn_state: AscendAttentionState, num_tokens: int) -> bool: # do the checks for dp + dbo if attn_state in [ AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ]: return False # considering the case that one dp rank may enable dbo while others may not if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO: return False # TODO: remove it if token-level microbatch is enabled [token_index, seq_index] = compute_split_seq_index(query_lens, attn_state, num_tokens) if token_index == 0 or seq_index == 0 or seq_index == len( query_lens) or num_tokens < 256: return False return True def get_model(self) -> nn.Module: # get raw model out of the aclgraph wrapper. if isinstance(self.model, ACLGraphWrapper): return self.model.unwrap() return self.model def get_supported_generation_tasks(self) -> "list[GenerationTask]": model = self.get_model() supported_tasks = list[GenerationTask]() if is_text_generation_model(model): supported_tasks.append("generate") if supports_transcription(model): if model.supports_transcription_only: return ["transcription"] supported_tasks.append("transcription") return supported_tasks def get_supported_tasks(self) -> "tuple[SupportedTask, ...]": tasks = list[SupportedTask]() if self.model_config.runner_type == "generate": tasks.extend(self.get_supported_generation_tasks()) if self.model_config.runner_type == "pooling": tasks.extend(self.get_supported_pooling_tasks()) return tuple(tasks) def _make_attention_mask(self, seq_lens, position, attn_state) -> torch.Tensor: # Chunk Prefill situation. if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa: if torch.version.cann.startswith("8.3"): return self.attn_mask_builder.get_splitfuse_attn_mask() else: return self.attn_mask_builder.get_splitfuse_attn_mask( seq_lens, position, self.dtype, self.device) # Prefill without cache situation. elif attn_state == AscendAttentionState.PrefillNoCache: max_seq_len = max(seq_lens.max().item(), 0) return self.attn_mask_builder.get_attn_mask( max_seq_len, self.dtype, self.device) # Prefill with cache hit. elif attn_state == AscendAttentionState.PrefillCacheHit: return self.attn_mask_builder.get_attn_mask( 128, self.dtype, self.device) # Decode-only situation. else: return None def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): mrope_pos_ptr = 0 for index, req_id in enumerate(self.input_batch.req_ids): req = self.requests[req_id] assert req.mrope_positions is not None num_computed_tokens = \ self.input_batch.num_computed_tokens_cpu[index] num_scheduled_tokens = \ scheduler_output.num_scheduled_tokens[req_id] num_prompt_tokens = len(req.prompt_token_ids) if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens) completion_part_len = max( 0, num_scheduled_tokens - prompt_part_len) else: prompt_part_len = num_scheduled_tokens completion_part_len = 0 assert num_scheduled_tokens == prompt_part_len + completion_part_len if prompt_part_len > 0: # prompt's mrope_positions are pre-computed dst_start = mrope_pos_ptr dst_end = mrope_pos_ptr + prompt_part_len src_start = num_computed_tokens src_end = num_computed_tokens + prompt_part_len self.mrope_positions_cpu[:, dst_start:dst_end] = \ req.mrope_positions[:,src_start:src_end] mrope_pos_ptr += prompt_part_len if completion_part_len > 0: # compute completion's mrope_positions on-the-fly dst_start = mrope_pos_ptr dst_end = mrope_pos_ptr + completion_part_len MRotaryEmbedding.get_next_input_positions_tensor( out=self.mrope_positions_np, out_offset=dst_start, mrope_position_delta=req.mrope_position_delta, context_len=num_computed_tokens + prompt_part_len, num_new_tokens=completion_part_len, ) mrope_pos_ptr += completion_part_len def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: return # Batch the multi-modal inputs. mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( scheduler_output) encoder_outputs = [] for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( mm_kwargs, device=self.device, pin_memory=True, ): # Run the encoder. # `curr_group_outputs` is either of the following: # 1. A tensor of shape (num_items, feature_size, hidden_size) # in case feature_size is fixed across all multimodal items. # 2. A list or tuple (length: num_items) of tensors, each of shape # (feature_size, hidden_size) in case the feature size is dynamic # depending on the input multimodal items. curr_group_outputs = self.model.get_multimodal_embeddings( **mm_kwargs_group) sanity_check_mm_encoder_outputs( curr_group_outputs, expected_num_items=num_items, ) for output in curr_group_outputs: encoder_outputs.append(output) for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): self.encoder_cache[mm_hash] = scatter_mm_placeholders( output, is_embed=pos_info.is_embed, ) def _batch_mm_kwargs_from_scheduler( self, scheduler_output: "SchedulerOutput", ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]: """Batch multimodal kwargs from scheduled encoder inputs. Args: scheduler_output: The scheduler output containing scheduled encoder inputs. Returns: A tuple of (mm_kwargs, req_ids_pos) where: - mm_kwargs: List of multimodal kwargs items to be batched - mm_hashes_pos: List of (mm_hash, position_info) tuples """ scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: return [], [] # Batch the multi-modal inputs. mm_kwargs = list[MultiModalKwargsItem]() # list of tuple (mm_hash, position_info) mm_hashes_pos = list[tuple[str, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] assert req_state.mm_features is not None for mm_input_id in encoder_input_ids: mm_feature = req_state.mm_features[mm_input_id] mm_hash = mm_feature.identifier mm_kwargs.append(mm_feature.data) mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) return mm_kwargs, mm_hashes_pos def _gather_mm_embeddings( self, scheduler_output: "SchedulerOutput", ) -> list[torch.Tensor]: def _iter_mm_features(req_state: CachedRequestState): assert req_state.mm_features is not None for mm_feature in req_state.mm_features: pos_info = mm_feature.mm_position yield mm_feature.identifier, pos_info, getattr( pos_info, "is_embed", None) mm_embeds: list[torch.Tensor] = [] for req_id in self.input_batch.req_ids: num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] req_state = self.requests[req_id] num_computed_tokens = req_state.num_computed_tokens for mm_hash, pos_info, is_embed in _iter_mm_features(req_state): start_pos = pos_info.offset num_encoder_tokens = pos_info.length if start_pos >= num_computed_tokens + num_scheduled_tokens: break if start_pos + num_encoder_tokens <= num_computed_tokens: continue start_idx = max(num_computed_tokens - start_pos, 0) end_idx = min( num_computed_tokens - start_pos + num_scheduled_tokens, num_encoder_tokens, ) assert start_idx < end_idx encoder_output = self.encoder_cache.get(mm_hash, None) assert encoder_output is not None, \ f"Encoder cache miss for {mm_hash}." if is_embed is not None: is_embed = is_embed[start_idx:end_idx] mm_embeds_item = gather_mm_placeholders( encoder_output[start_idx:end_idx], is_embed=is_embed, ) mm_embeds.append(mm_embeds_item) return mm_embeds def _get_cumsum_and_arange( self, num_tokens: np.ndarray, cumsum_dtype: Optional[np.dtype] = None, ) -> tuple[np.ndarray, np.ndarray]: """Get the cumulative sum and batched arange of the given array. # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) # Equivalent to but faster than: # np.concatenate([np.arange(n) for n in num_tokens]) """ # Step 1. [2, 5, 3] -> [2, 7, 10] cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) total_num_tokens = cu_num_tokens[-1] # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] arange = self.arange_np[:total_num_tokens] - cumsums_offsets return cu_num_tokens, arange def _prepare_input_ids(self, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray) -> None: """Prepare the input IDs for the current batch. Carefully handles the `prev_sampled_token_ids` which can be cached from the previous engine iteration, in which case those tokens on the NPU need to be copied into the corresponding slots into input_ids.""" if self.input_batch.prev_sampled_token_ids is None: # Normal scheduling case self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) return # Async scheduling case, where some decode requests from the previous # iteration won't have entries in input_ids_cpu and need to be copied # on the NPU from prev_sampled_token_ids. prev_req_id_to_index = self.input_batch.prev_req_id_to_index assert prev_req_id_to_index is not None flattened_indices = [] prev_common_req_indices = [] indices_match = True max_flattened_index = -1 for req_id, cur_index in self.input_batch.req_id_to_index.items(): if (prev_index := prev_req_id_to_index.get(req_id)) is not None: prev_common_req_indices.append(prev_index) # We need to compute the flattened input_ids index of the # last token in each common request. flattened_index = cu_num_tokens[cur_index].item() - 1 flattened_indices.append(flattened_index) indices_match &= (prev_index == flattened_index) max_flattened_index = max(max_flattened_index, flattened_index) num_commmon_tokens = len(flattened_indices) if num_commmon_tokens < total_num_scheduled_tokens: # If not all requests are decodes from the last iteration, # We need to copy the input_ids_cpu to the NPU first. self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) if num_commmon_tokens == 0: # No requests in common with the previous iteration # So input_ids_cpu will have all the input ids. return if indices_match and max_flattened_index == (num_commmon_tokens - 1): # Common-case optimization: the batch is unchanged # and no reordering happened. # The indices are both the same permutation of 0..N-1 so # we can copy directly using a single slice. self.input_ids[:num_commmon_tokens].copy_( self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0], non_blocking=True) return # Upload the index tensors asynchronously # so the scatter can be non-blocking. input_ids_index_tensor = torch.tensor(flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory).to( self.device, non_blocking=True) prev_common_req_indices_tensor = torch.tensor( prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory).to(self.device, non_blocking=True) self.input_ids.scatter_(dim=0, index=input_ids_index_tensor, src=self.input_batch.prev_sampled_token_ids[ prev_common_req_indices_tensor, 0]) def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ Update the order of requests in the batch based on the attention backend's needs. For example, some attention backends (namely MLA) may want to separate requests based on if the attention computation will be compute-bound or memory-bound. Args: scheduler_output: The scheduler output. """ # Attention free models have zero kv_cache_goups, however models # like Mamba are also attention free but use the kv_cache for # keeping its internal state. This is why we check the number # of kv_cache groups instead of solely checking # for self.model_config.is_attention_free. if len(self.kv_cache_config.kv_cache_groups) == 0: return if self.reorder_batch_threshold is not None: reorder_batch_to_split_decodes_and_prefills( self.input_batch, scheduler_output, decode_threshold=self.reorder_batch_threshold) def _prepare_inputs( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor, int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], int]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. self.input_batch.block_table.commit_block_table(num_reqs) # Get the number of scheduled tokens for each request. req_ids = self.input_batch.req_ids tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] num_scheduled_tokens = np.array(tokens, dtype=np.int32) max_num_scheduled_tokens = num_scheduled_tokens.max() num_valid_tokens = np.array([ num_tokens - len(scheduler_output.scheduled_spec_decode_tokens.get(i, [])) for num_tokens, i in zip(tokens, req_ids) ], dtype=np.int32) if (self.use_aclgraph and total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]): # Add padding to the batch size. num_input_tokens = self.vllm_config.pad_for_cudagraph( total_num_scheduled_tokens) else: # Eager mode. num_input_tokens = total_num_scheduled_tokens # Get the attention state. attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens) self.attn_state = attn_state # type: ignore # Determine if it's a splitfuse batch with_prefill = attn_state not in [ AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ] self.query_lens = torch.from_numpy(num_scheduled_tokens) enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), attn_state, total_num_scheduled_tokens) # Get info across DP ranks. # NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP, # Otherwise, it's just max_tokens_across_dp_cpu (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo) = self._sync_metadata_across_dp(num_input_tokens, with_prefill, enable_dbo) # TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens # We should consider removing maybe_padded_num_tokens later num_input_tokens = maybe_padded_num_tokens # Hot-Swap lora model if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] cu_num_tokens, arange = self._get_cumsum_and_arange( num_scheduled_tokens) positions_np = self.positions_np[:total_num_scheduled_tokens] np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np) # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: self._calc_mrope_positions(scheduler_output) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions[:, :total_num_scheduled_tokens].copy_( self.mrope_positions_cpu[:, :total_num_scheduled_tokens], non_blocking=True) # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] # where M is the max_model_len. token_indices = (positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]) # Prepare input_ids. # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), 0, torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) # Prepare some information for building Attention-Metadata # Compute and commit slot mapping self.input_batch.block_table.compute_slot_mapping( req_indices, positions_np) self.input_batch.block_table.commit_slot_mapping( total_num_scheduled_tokens) self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens self.query_start_loc[:num_reqs + 1].copy_( self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) self.seq_lens_np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens) self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True) # Fill unused with -1. Needed for reshape_and_cache self.query_start_loc[num_reqs + 1:].fill_(-1) self.seq_lens[num_reqs:].fill_(0) self.query_lens = torch.from_numpy(num_scheduled_tokens) # Copy the tensors to the NPU. self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_() self.positions[:num_input_tokens].copy_( self.positions_cpu[:num_input_tokens], non_blocking=True) # Make Attention metadata positions_cpu = self.positions_cpu[:num_input_tokens] positions = self.positions[:num_input_tokens] seq_lens_cpu = self.seq_lens_cpu[:num_reqs] self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu, position=positions_cpu, attn_state=attn_state) self.attn_state = attn_state # type: ignore self.with_prefill = with_prefill self.num_tokens_across_dp = num_tokens_across_dp self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens) attn_metadata: dict[str, Any] = {} # Prepare input_ids token_indices = (positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]) torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), 0, torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) # Copy the tensors to the NPU. self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) # _prepare_inputs may reorder the batch, so we must gather # multi-modal outputs after that to ensure the correct order if self.is_multimodal_model: # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) mm_embeds = self._gather_mm_embeddings(scheduler_output) # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. input_ids = self.input_ids[:total_num_scheduled_tokens] if mm_embeds: inputs_embeds = self.model.get_input_embeddings( input_ids, mm_embeds) else: inputs_embeds = self.model.get_input_embeddings(input_ids) # TODO(woosuk): Avoid the copy. Optimize. self.inputs_embeds[:total_num_scheduled_tokens].copy_( inputs_embeds) inputs_embeds = self.inputs_embeds[:num_input_tokens] input_ids = None else: # For text-only models, we use token ids as input. # While it is possible to use embeddings as input just like the # multimodal models, it is not desirable for performance since # then the embedding layer is not included in the ACL graph. input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None positions = self.positions[:num_input_tokens] input_ids, positions = self._update_input_ids_and_positions( input_ids, positions, num_input_tokens, with_prefill, maybe_padded_num_tokens) if get_pp_group().is_first_rank: intermediate_tensors = None else: assert intermediate_tensors is not None assert self.intermediate_tensors is not None for k, v in intermediate_tensors.items(): self.intermediate_tensors[k][:num_input_tokens].copy_( v[:num_input_tokens], non_blocking=True) intermediate_tensors = IntermediateTensors({ k: v[:num_input_tokens] for k, v in self.intermediate_tensors.items() }) use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: # NOTE(woosuk): Due to chunked prefills, the batch may contain # partial requests. While we should not sample any token # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. spec_decode_metadata = None logits_indices = torch.from_numpy(cu_num_tokens - 1).to( self.device, non_blocking=True) else: # Get the number of draft tokens for each request. # Iterate over the dictionary rather than all requests since not all # requests have draft tokens. num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) for req_id, draft_token_ids in ( scheduler_output.scheduled_spec_decode_tokens.items()): req_idx = self.input_batch.req_id_to_index[req_id] num_draft_tokens[req_idx] = len(draft_token_ids) spec_decode_metadata = self._calc_spec_decode_metadata( num_draft_tokens, cu_num_tokens) logits_indices = spec_decode_metadata.logits_indices self.num_draft_tokens.np[:num_reqs] = num_draft_tokens self.num_draft_tokens.np[num_reqs:].fill(0) self.num_draft_tokens.copy_to_gpu() # Used in the below loop. # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] num_computed_tokens_cpu = ( self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) spec_decode_common_attn_metadata = None if use_spec_decode and self.need_accepted_tokens: self.num_accepted_tokens.np[:num_reqs] = ( self.input_batch.num_accepted_tokens_cpu[:num_reqs]) self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): blk_table = self.input_batch.block_table[kv_cache_group_id] blk_table_tensor = blk_table.get_device_tensor() slot_mapping = blk_table.slot_mapping_cpu[: total_num_scheduled_tokens] self.slot_mapping[:total_num_scheduled_tokens].copy_( slot_mapping[:total_num_scheduled_tokens], non_blocking=True, ) # Make AscendCommonAttentionMetadata common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=self.query_start_loc[:num_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], seq_lens_cpu=self.seq_lens_cpu, seq_lens=self.seq_lens_cpu[:num_reqs], num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, actual_seq_lengths_q=self.actual_seq_lengths_q, # TODO: change this to the right block table for linear attn block_table_tensor=blk_table_tensor[:num_reqs], slot_mapping=self.slot_mapping, num_computed_tokens_cpu=num_computed_tokens_cpu, positions=self.positions, attn_mask=self.attn_mask, spec_attn_mask=self.spec_attn_mask, attn_state=self.attn_state, enable_dbo_across_dp=enable_dbo, is_only_prefill=bool(np.all(num_valid_tokens != 1)), max_query_len=max_num_scheduled_tokens, graph_pad_size=self.graph_pad_size, decode_token_per_req=self.decode_token_per_req, ) if self.speculative_config and \ spec_decode_common_attn_metadata is None: spec_decode_common_attn_metadata = common_attn_metadata for attn_group in self.attn_groups[kv_cache_group_id]: common_prefix_len = 0 extra_attn_metadata_args = {} builder = attn_group.get_metadata_builder() if isinstance(builder, GDNAttentionMetadataBuilder): if use_spec_decode: extra_attn_metadata_args = dict( num_accepted_tokens=self.num_accepted_tokens. gpu[:num_reqs], num_draft_tokens=self.num_draft_tokens. gpu[:num_reqs], ) attn_metadata_i = builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, **extra_attn_metadata_args) else: attn_metadata_i = builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, model=self.model, **extra_attn_metadata_args) if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa: attn_metadata_i.num_input_tokens = num_input_tokens for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i if lmhead_tp_enable(): max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs logits_indices = nn.functional.pad( logits_indices, (0, max_num_reqs_across_dp - logits_indices.shape[0])) return (attn_metadata, positions, num_scheduled_tokens, num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens, logits_indices, spec_decode_metadata, input_ids, inputs_embeds, intermediate_tensors, max_num_scheduled_tokens) def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, maybe_padded_num_tokens, input_ids, positions, intermediate_tensors, inputs_embeds): assert self.model is not None hidden_states = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) forward_context = get_forward_context() if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: update_attn_params(self.update_stream, forward_context, positions.shape[0]) if get_forward_context().sp_enabled: hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) pad_size = get_forward_context().pad_size if pad_size > 0: hidden_states = hidden_states[:-pad_size, :] return hidden_states def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens): ascend_config = get_ascend_config() if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens): attn_state = AscendAttentionState.PrefillNoCache # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache. elif np.all(num_scheduled_tokens == 1): attn_state = AscendAttentionState.DecodeOnly if self.speculative_config and self.speculative_config.method == 'deepseek_mtp': # SpecDecoding now supports seq_len=1 and seq_len=2 # In Prefilling Decoding Disaggregation scenario, SpecDecoding need to supports seq_len=1 attn_state = AscendAttentionState.SpecDecoding # Speculative decoding. elif np.all(num_valid_tokens == 1): if self.drafter and (self.drafter.name == SpecDcodeType.EAGLE or self.drafter.name == SpecDcodeType.EAGLE3): attn_state = AscendAttentionState.ChunkedPrefill else: attn_state = AscendAttentionState.SpecDecoding # splitfuse elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled: attn_state = AscendAttentionState.ChunkedPrefill else: attn_state = AscendAttentionState.PrefillCacheHit return attn_state def _update_graph_pad_size(self, with_prefill, graph_pad_size): self.graph_pad_size = -1 def _update_input_ids_and_positions(self, input_ids, positions, num_input_tokens, with_prefill, maybe_padded_num_tokens): if self.uses_mrope: positions = self.mrope_positions[:, :num_input_tokens] return input_ids, positions def _calc_spec_decode_metadata( self, num_draft_tokens: np.ndarray, cu_num_scheduled_tokens: np.ndarray, ) -> SpecDecodeMetadata: # Inputs: # cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209] # num_draft_tokens: [ 3, 0, 2, 0, 1] # Outputs: # cu_num_draft_tokens: [ 3, 3, 5, 5, 6] # logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106, # 206, 207, 208] # target_logits_indices: [ 0, 1, 2, 5, 6, 9] # bonus_logits_indices: [ 3, 4, 7, 8, 10] # Compute the logits indices. # [4, 1, 3, 1, 2] num_sampled_tokens = num_draft_tokens + 1 # Step 1. [4, 5, 8, 9, 11] cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32) total_num_sampled_tokens = cu_num_sampled_tokens[-1] # Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens) # Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets # Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] logits_indices = np.repeat( cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] logits_indices += arange # Compute the bonus logits indices. bonus_logits_indices = cu_num_sampled_tokens - 1 # Compute the draft logits indices. # [3, 3, 5, 5, 6] cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) total_num_draft_tokens = cu_num_draft_tokens[-1] # [0, 0, 0, 3, 3, 5] cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens, num_draft_tokens) # [0, 1, 2, 0, 1, 0] arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets # [0, 0, 0, 5, 5, 9] target_logits_indices = np.repeat( cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens) # [0, 1, 2, 5, 6, 9] target_logits_indices += arange # TODO: Optimize the CPU -> NPU copy. cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( self.device, non_blocking=True) logits_indices = torch.from_numpy(logits_indices).to(self.device, non_blocking=True) target_logits_indices = torch.from_numpy(target_logits_indices).to( self.device, non_blocking=True) bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( self.device, non_blocking=True) # Compute the draft token ids. # draft_token_indices: [ 1, 2, 3, 105, 106, 208] draft_token_ids = self.input_ids[logits_indices] draft_token_ids = draft_token_ids[target_logits_indices + 1] metadata = SpecDecodeMetadata( draft_token_ids=draft_token_ids, num_draft_tokens=num_draft_tokens.tolist(), cu_num_draft_tokens=cu_num_draft_tokens, target_logits_indices=target_logits_indices, bonus_logits_indices=bonus_logits_indices, logits_indices=logits_indices, ) return metadata def apply_grammar_bitmask( self, scheduler_output: "SchedulerOutput", logits: torch.Tensor, ) -> torch.Tensor: grammar_bitmask = scheduler_output.grammar_bitmask # We receive the structured output bitmask from the scheduler, # compacted to contain bitmasks only for structured output requests. # The order of the requests in the bitmask is not guaranteed to be the # same as the order of the requests in the gpu runner's batch. We need # to sort the bitmask to match the order of the requests used here. # Get the batch indices of the structured output requests. # Keep track of the number of speculative tokens scheduled for every # request in the batch, as the logit indices are offset by this amount. struct_out_req_batch_indices: dict[str, int] = {} cumulative_offset = 0 seq = sorted(self.input_batch.req_id_to_index.items(), key=lambda x: x[1]) for req_id, batch_index in seq: logit_index = batch_index + cumulative_offset cumulative_offset += len( scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) if req_id in scheduler_output.structured_output_request_ids: struct_out_req_batch_indices[req_id] = logit_index out_indices = [] # Reorder the bitmask to match the order of the requests in the batch. sorted_bitmask = np.zeros_like(grammar_bitmask, shape=(logits.shape[0], grammar_bitmask.shape[1])) cumulative_index = 0 seq = sorted(scheduler_output.structured_output_request_ids.items(), key=lambda x: x[1]) for req_id, _ in seq: logit_index = struct_out_req_batch_indices[req_id] num_spec_tokens = len( scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) for i in range(1 + num_spec_tokens): sorted_bitmask[logit_index + i] = \ grammar_bitmask[cumulative_index + i] out_indices.append(logit_index + i) cumulative_index += 1 + num_spec_tokens grammar_bitmask = sorted_bitmask # Serialization of np.ndarray is much more efficient than a tensor, # so we receive it in that format. grammar_bitmask = torch.from_numpy(grammar_bitmask) # NOTE: # 1. XGrammar bitmask applying only supports CPU and GPU. # 2. The logits and bitmask should be on the same device. # 3. XGrammar logits on CPU only supports float32 dtype. logits_dtype = logits.dtype logits = logits.to("cpu").float() xgr.apply_token_bitmask_inplace( logits, grammar_bitmask, indices=out_indices, ) return logits.to(self.device).to(logits_dtype) def propose_draft_token_ids( self, valid_sampled_token_ids: list[list[int]], sampling_metadata: SamplingMetadata, scheduler_output: "SchedulerOutput", spec_decode_metadata: SpecDecodeMetadata, positions: torch.Tensor, num_scheduled_tokens: int, hidden_states: torch.Tensor, attn_metadata: dict[str, Any], aux_hidden_states: torch.Tensor = None, ) -> Optional[list[list[int]]]: if not self.drafter: # Speculative decoding is not enabled. draft_token_ids = None else: draft_token_ids = self.drafter.generate_token_ids( valid_sampled_token_ids, sampling_metadata, scheduler_output, spec_decode_metadata, positions, num_scheduled_tokens, hidden_states, attn_metadata, aux_hidden_states) return draft_token_ids def _pool( self, hidden_states: torch.Tensor, num_scheduled_tokens: int, num_scheduled_tokens_np: np.ndarray, finished_sending: Optional[set[str]] = None, finished_recving: Optional[set[str]] = None, kv_connector_output: Optional["KVConnectorOutput"] = None, ) -> ModelRunnerOutput: assert self.input_batch.num_reqs ==\ len(self.input_batch.pooling_params), \ "Either all or none of the requests in" \ " a batch must be pooling request" hidden_states = hidden_states[:num_scheduled_tokens] pooling_metadata = self.input_batch.pooling_metadata pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(), device=hidden_states.device) seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs] model = cast(VllmModelForPooling, self.model) raw_pooler_output = model.pooler( hidden_states=hidden_states, pooling_metadata=pooling_metadata, ) raw_pooler_output = json_map_leaves( lambda x: x.to("cpu", non_blocking=True), raw_pooler_output, ) torch.npu.synchronize() pooler_output: list[Optional[torch.Tensor]] = [] for raw_output, seq_len, prompt_len in zip( raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens): output = raw_output if seq_len == prompt_len else None pooler_output.append(output) return ModelRunnerOutput( req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=[], logprobs=None, prompt_logprobs_dict={}, pooler_output=pooler_output, kv_connector_output=kv_connector_output, ) def _select_moe_comm_method(self, num_tokens: int, with_prefill: bool) -> MoECommType: """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all are designed for expert parallelism. 2. If expert parallel is enabled, we need to consider the soc version and the number of tokens. This is based on the observation that all-gather is more efficient than all-to-all when running on A2. a. For A2, we choose from MC2 and all-gather. b. For A3, we choose from MC2 and all-to-all. In both cases, we use MC2 when the number of tokens is smaller than a its capacity threshold. Args: num_tokens (int): The number of tokens in the current batch. Raises: ValueError: If the soc version is unsupported. Returns: MoECommType: The selected MoE communication method. """ soc_version = get_ascend_soc_version() quant_type = getattr(self.vllm_config.model_config.hf_config, 'moe_quantize', None) model_type = self.vllm_config.model_config.hf_config.model_type if not self.parallel_config.enable_expert_parallel: moe_comm_type = MoECommType.ALLGATHER elif soc_version in {AscendSocVersion.A2}: if (num_tokens <= self.mc2_tokens_capacity and self.parallel_config.world_size_across_dp >= 16): moe_comm_type = MoECommType.MC2 else: # Currently, w4a8_dynamic does not support allgatherep if quant_type == "w4a8_dynamic": moe_comm_type = MoECommType.ALLTOALL else: moe_comm_type = MoECommType.ALLGATHER elif soc_version in {AscendSocVersion.A3}: moe_comm_type = (MoECommType.MC2 if num_tokens <= self.mc2_tokens_capacity else MoECommType.ALLTOALL) else: raise ValueError(f"Unsupported soc_version: {soc_version}") if moe_comm_type == MoECommType.ALLGATHER and with_prefill: moe_comm_type = MoECommType.NAIVE_MULTICAST # PanguProMoE only supports allgather if model_type == "PanguProMoE": moe_comm_type = MoECommType.ALLGATHER if is_global_first_rank(): logger.debug(f"num_tokens: {num_tokens}, " f"moe_comm_type: {moe_comm_type}") return moe_comm_type @torch.inference_mode() def execute_model( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: with ProfileExecuteDuration().capture_async("prepare input"): self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: if not has_kv_transfer_group(): logger.debug( "skip this step for we receive the data from remote disaggregate prefill node" ) # Return empty ModelRunnerOuptut if there's no work to do. return EMPTY_MODEL_RUNNER_OUTPUT return self.kv_connector_no_forward(scheduler_output) if self.dynamic_eplb: self.eplb_updator.forward_before() (attn_metadata, positions, num_scheduled_tokens_np, num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens, logits_indices, spec_decode_metadata, input_ids, inputs_embeds, intermediate_tensors, max_query_len) = (self._prepare_inputs(scheduler_output, intermediate_tensors)) if self.dynamic_eplb: self.eplb_updator.take_update_info_from_eplb_process() moe_comm_type = self._select_moe_comm_method(num_input_tokens, self.with_prefill) uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( scheduler_output.total_num_scheduled_tokens == self.input_batch.num_reqs * max_query_len) batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, uniform_decode=uniform_decode) aclgraph_runtime_mode, batch_descriptor = \ self.aclgraph_dispatcher.dispatch(batch_descriptor) # Run forward pass with ProfileExecuteDuration().capture_async("forward"): with set_ascend_forward_context( attn_metadata, self.vllm_config, num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, with_prefill=self.with_prefill, reserved_mc2_mask=self.reserved_mc2_mask, moe_comm_type=moe_comm_type, aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor, num_actual_tokens=scheduler_output. total_num_scheduled_tokens, prefetch_stream=self.prefetch_stream, model_instance=self.model): self.maybe_setup_kv_connector(scheduler_output) hidden_states = self._generate_process_reqs_hidden_states( attn_metadata, self.with_prefill, maybe_padded_num_tokens, input_ids, positions, intermediate_tensors, inputs_embeds) self.maybe_wait_for_kv_save() finished_sending, finished_recving = self.get_finished_kv_transfer( scheduler_output) aux_hidden_states = None if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: hidden_states, aux_hidden_states = hidden_states kv_connector_output = KVConnectorOutput( finished_sending=finished_sending, finished_recving=finished_recving) finished_sending = None finished_recving = None with ProfileExecuteDuration().capture_async("post process"): # Broadcast PP output for external_launcher (torchrun) # to make sure we are synced across pp ranks # TODO: Support overlapping mirco-batches # https://github.com/vllm-project/vllm/issues/18019 broadcast_pp_output = \ self.parallel_config.distributed_executor_backend \ == "external_launcher" and len(get_pp_group().ranks) > 0 if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. if not broadcast_pp_output: hidden_states.kv_connector_output = kv_connector_output return hidden_states assert isinstance(hidden_states, IntermediateTensors) get_pp_group().send_tensor_dict( hidden_states.tensors, all_gather_group=get_tp_group()) logits = None else: if self.input_batch.pooling_params: return self._pool( hidden_states, scheduler_output.total_num_scheduled_tokens, num_scheduled_tokens_np, finished_sending, finished_recving, kv_connector_output) sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states) if broadcast_pp_output: model_output_broadcast_data = { "logits": logits.contiguous(), } if logits is not None else {} model_output_broadcast_data = get_pp_group( ).broadcast_tensor_dict(model_output_broadcast_data, src=len(get_pp_group().ranks) - 1) assert model_output_broadcast_data is not None logits = model_output_broadcast_data["logits"] # Apply structured output bitmasks if present if scheduler_output.grammar_bitmask is not None: logits = self.apply_grammar_bitmask(scheduler_output, logits) # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata if spec_decode_metadata is None: if lmhead_tp_enable() and logits is not None: logits = logits[:self.input_batch.num_reqs] sampler_output = self.sampler( logits=logits, sampling_metadata=sampling_metadata, ) else: if lmhead_tp_enable() and logits is not None: logits = logits[:len(spec_decode_metadata.logits_indices)] # When indexing with a tensor (bonus_logits_indices), PyTorch # creates a new tensor with separate storage from the original # logits tensor. This means any in-place operations on bonus_logits # won't affect the original logits tensor. assert logits is not None bonus_logits = logits[ spec_decode_metadata.bonus_logits_indices] sampler_output = self.sampler( logits=bonus_logits, sampling_metadata=sampling_metadata, ) bonus_token_ids = sampler_output.sampled_token_ids # Just like `bonus_logits`, `target_logits` is a new tensor with # separate storage from the original `logits` tensor. Therefore, # it is safe to update `target_logits` in place. target_logits = logits[ spec_decode_metadata.target_logits_indices] output_token_ids = self.rejection_sampler( spec_decode_metadata, None, # draft_probs target_logits, bonus_token_ids, sampling_metadata, ) sampler_output.sampled_token_ids = output_token_ids if self.need_accepted_tokens: self._update_states_after_model_execute(output_token_ids) discard_sampled_tokens_req_indices: list[int] = [] # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. discard_sampled_tokens_req_indices = [] for i, req_id in enumerate(self.input_batch.req_ids): req_state = self.requests[req_id] seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) if seq_len < req_state.num_tokens: # Ignore the sampled token. # Rewind the generator state as if the token was not sampled. generator = self.input_batch.generators.get(i) if generator is not None: generator.set_offset(generator.get_offset() - 4) discard_sampled_tokens_req_indices.append(i) # Copy some objects so they don't get modified after returning. # This is important when using async scheduling. req_ids_output_copy = self.input_batch.req_ids.copy() req_id_to_index_output_copy = \ self.input_batch.req_id_to_index.copy() # NOTE: NPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. logprobs_tensors = sampler_output.logprobs_tensors logprobs_lists = logprobs_tensors.tolists() \ if logprobs_tensors is not None else None # Compute prompt logprobs if needed. prompt_logprobs_dict = self._get_prompt_logprobs_dict( hidden_states[:scheduler_output.total_num_scheduled_tokens], scheduler_output, ) num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] sampled_token_ids = sampler_output.sampled_token_ids if not self.use_async_scheduling: # Get the valid generated tokens. max_gen_len = sampled_token_ids.shape[-1] if max_gen_len == 1: # No spec decode tokens. valid_sampled_token_ids = sampled_token_ids.tolist() else: # Includes spec decode tokens. valid_sampled_token_ids = self.rejection_sampler.parse_output( sampled_token_ids, self.input_batch.vocab_size, ) # Mask out the sampled tokens that should not be sampled. for i in discard_sampled_tokens_req_indices: valid_sampled_token_ids[i].clear() else: valid_sampled_token_ids = [] invalid_req_indices = list(discard_sampled_tokens_req_indices) invalid_req_indices_set = set(invalid_req_indices) assert sampled_token_ids.shape[-1] == 1 # Cache the sampled tokens on the NPU and avoid CPU sync. # These will be copied into input_ids in the next step # when preparing inputs. self.input_batch.prev_sampled_token_ids = \ sampled_token_ids self.input_batch.prev_sampled_token_ids_invalid_indices = \ invalid_req_indices_set self.input_batch.prev_req_id_to_index = { req_id: i for i, req_id in enumerate(self.input_batch.req_ids) if i not in invalid_req_indices_set } # Cache the sampled tokens in the model runner, so that the scheduler # doesn't need to send them back. # NOTE(woosuk): As an exception, when using PP, the scheduler sends # the sampled tokens back, because there's no direct communication # between the first-stage worker and the last-stage worker. for req_idx in range(num_sampled_tokens): if self.use_async_scheduling: sampled_ids = [-1] * 1 if \ req_idx not in invalid_req_indices_set else None else: sampled_ids = valid_sampled_token_ids[req_idx] if not sampled_ids: continue start_idx = self.input_batch.num_tokens_no_spec[req_idx] end_idx = start_idx + len(sampled_ids) assert end_idx <= self.model_config.max_model_len, ( "Sampled token IDs exceed the max model length. " f"Total number of tokens: {end_idx} > max_model_len: " f"{self.model_config.max_model_len}") self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx req_id = self.input_batch.req_ids[req_idx] req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) if self.speculative_config: self._draft_token_ids = self.propose_draft_token_ids( valid_sampled_token_ids, sampling_metadata, scheduler_output, spec_decode_metadata, positions, scheduler_output.total_num_scheduled_tokens, hidden_states, attn_metadata, aux_hidden_states, ) if has_kv_transfer_group(): get_kv_transfer_group().clear_connector_metadata() extra_args = ({"kv_connector_output": kv_connector_output}) model_runner_output = ModelRunnerOutput( req_ids=req_ids_output_copy, req_id_to_index=req_id_to_index_output_copy, sampled_token_ids=valid_sampled_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, pooler_output=[], **extra_args, ) durations = ProfileExecuteDuration().pop_captured_sync() if durations: dr_str = [ f"[{tag}]:{duration:.2f}ms" for tag, duration in durations.items() ] captured_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill" logger.info("Profile execute duration [%s]:%s", captured_name, " ".join(dr_str)) if self.dynamic_eplb: self.eplb_updator.forward_end() if not self.use_async_scheduling: return model_runner_output return AsyncNPUModelRunnerOutput( model_runner_output=model_runner_output, sampled_token_ids=sampled_token_ids, invalid_req_indices=invalid_req_indices, async_output_copy_stream=self.async_output_copy_stream, ) def take_draft_token_ids(self) -> Optional[DraftTokenIds]: if self._draft_token_ids is None: return None req_ids = self.input_batch.req_ids if isinstance(self._draft_token_ids, torch.Tensor): draft_token_ids = self._draft_token_ids.tolist() else: draft_token_ids = self._draft_token_ids self._draft_token_ids = None return DraftTokenIds(req_ids, draft_token_ids) def kv_connector_no_forward( self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: with set_ascend_forward_context(None, self.vllm_config): self.maybe_setup_kv_connector(scheduler_output) finished_sending, finished_recving = ( self.get_finished_kv_transfer(scheduler_output)) # For the case of no forward caused by receiving remote kv, # one round of dummy inference is necessary # to prevent hang over the collective calls. output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) output.kv_connector_output = KVConnectorOutput( finished_sending=finished_sending, finished_recving=finished_recving) return output @staticmethod def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): # Update KVConnector with the KVConnector metadata forward(). if has_kv_transfer_group(): kv_connector = get_kv_transfer_group() assert isinstance(kv_connector, KVConnectorBase_V1) assert scheduler_output.kv_connector_metadata is not None kv_connector.bind_connector_metadata( scheduler_output.kv_connector_metadata) kv_connector.start_load_kv(get_forward_context()) @staticmethod def maybe_wait_for_kv_save() -> None: if has_kv_transfer_group(): get_kv_transfer_group().wait_for_save() @staticmethod def get_finished_kv_transfer( scheduler_output: "SchedulerOutput", ) -> tuple[Optional[set[str]], Optional[set[str]]]: if has_kv_transfer_group(): return get_kv_transfer_group().get_finished( scheduler_output.finished_req_ids) return None, None def _build_attention_metadata(self, create_mixed_batch, num_reqs, num_tokens, max_query_len, force_attention): attn_metadata: Optional[dict[str, Any]] = None if force_attention: attn_metadata = {} if create_mixed_batch: raise NotImplementedError( "force_attention=True is not supported for mixed batches.") else: seq_lens = self.model_config.max_model_len self.seq_lens_np[:num_reqs] = seq_lens self.seq_lens_np[num_reqs:] = 0 num_computed_tokens_cpu = ( self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): block_table_tensor = self.input_batch.block_table[ kv_cache_group_id].get_device_tensor() common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=self.query_start_loc[:num_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], seq_lens_cpu=self.seq_lens_cpu, seq_lens=self.seq_lens_cpu[:num_reqs], num_reqs=num_reqs, num_actual_tokens=num_tokens, actual_seq_lengths_q=self.actual_seq_lengths_q, block_table_tensor=block_table_tensor[:num_reqs], slot_mapping=self.slot_mapping, num_computed_tokens_cpu=num_computed_tokens_cpu, max_query_len=max_query_len, decode_token_per_req=self.decode_token_per_req, ) for attn_group in self.attn_groups[kv_cache_group_id]: builder = attn_group.get_metadata_builder() attn_metadata_i = builder.build_for_graph_capture( common_attn_metadata) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i return attn_metadata def _generate_dummy_run_hidden_states(self, with_prefill, is_torchair_compile, input_ids, positions, attn_metadata, num_tokens, intermediate_tensors, inputs_embeds): hidden_states = self.model(input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds) forward_context = get_forward_context() assert forward_context is not None if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: update_attn_params(self.update_stream, forward_context, positions.shape[0]) if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: hidden_states, _ = hidden_states else: hidden_states = hidden_states return hidden_states @torch.inference_mode() def _dummy_run( self, num_tokens: int, with_prefill: bool = False, is_torchair_compile: bool = False, aclgraph_runtime_mode: Optional[CUDAGraphMode] = None, force_attention: bool = False, uniform_decode: bool = False, ) -> torch.Tensor: # only support eager mode and piecewise graph now assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in { CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL } # Padding for DP (num_tokens, num_tokens_across_dp, with_prefill, _) = self._sync_metadata_across_dp(num_tokens, with_prefill, False) moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill) # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.seperate_routine(). This means that we are using # different graphs and/or modes for mixed prefill-decode batches vs. # uniform decode batches. A uniform decode batch means that all # requests have identical query length, except a potential virtual # request (shorter) in the batch account for padding. # Uniform decode batch could either be common pure decode, where # max_query_len == 1, or speculative decode, where # max_query_len == 1 + num_spec_decode_tokens. # When setting max_query_len = 1, we switch to and capture the optimized # routine of FA2 for pure decode, i.e., Flashdecode + an optimization # for GQA/MQA. max_query_len = self.uniform_decode_query_len if uniform_decode else \ num_tokens # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. assert num_tokens <= self.scheduler_config.max_num_batched_tokens max_num_reqs = self.scheduler_config.max_num_seqs if uniform_decode: num_reqs = cdiv(num_tokens, max_query_len) num_scheduled_tokens_list = [max_query_len] * num_reqs if num_tokens % max_query_len != 0: num_scheduled_tokens_list[-1] = num_tokens % max_query_len else: if with_prefill: num_reqs = num_tokens else: num_reqs = (num_tokens + self.decode_token_per_req - 1) // self.decode_token_per_req num_reqs = min(num_reqs, max_num_reqs) min_tokens_per_req = num_tokens // num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list[-1] += num_tokens % num_reqs assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) # Force dummy run on prefill stage when this node is deemed as kv producer. if self.is_kv_producer and not self.is_kv_consumer: with_prefill = True # TODO(cmq): check if with_prefill is reasonable attn_metadata = self._build_attention_metadata( False, num_reqs=num_reqs, num_tokens=num_tokens, max_query_len=max_query_len, force_attention=force_attention, ) if not self.in_profile_run and self.dynamic_eplb: self.eplb_updator.forward_before() with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] else: input_ids = self.input_ids[:num_tokens] inputs_embeds = None if self.uses_mrope: positions = self.mrope_positions[:, :num_tokens] else: positions = self.positions[:num_tokens] if get_pp_group().is_first_rank: intermediate_tensors = None else: if self.intermediate_tensors is None: self.intermediate_tensors = ( self.model.make_empty_intermediate_tensors( batch_size=num_tokens, dtype=self.dtype, device=self.device)) intermediate_tensors = IntermediateTensors({ k: v[:num_tokens] for k, v in self.intermediate_tensors.items() }) # filter out the valid batch descriptor _ag_mode, batch_descriptor = \ self.aclgraph_dispatcher.dispatch( BatchDescriptor(num_tokens=num_tokens, uniform_decode=uniform_decode)) if aclgraph_runtime_mode is not None: # we allow forcing NONE when the dispatcher disagrees to support # warm ups for aclgraph capture assert aclgraph_runtime_mode == CUDAGraphMode.NONE or \ aclgraph_runtime_mode == _ag_mode, ( f"Aclgraph runtime mode mismatch at dummy_run. " f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}.") else: aclgraph_runtime_mode = _ag_mode need_dummy_logits = (not self.in_profile_run and lmhead_tp_enable()) if need_dummy_logits: max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs dummy_indices = torch.zeros(max_num_reqs_across_dp, dtype=torch.int32) def dummy_compute_logits(hidden_states): return self.model.compute_logits( hidden_states[dummy_indices]) with set_ascend_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp, with_prefill=with_prefill, in_profile_run=self.in_profile_run, reserved_mc2_mask=self.reserved_mc2_mask, moe_comm_type=moe_comm_type, num_actual_tokens=0, aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor, prefetch_stream=self.prefetch_stream, model_instance=self.model): hidden_states = self._generate_dummy_run_hidden_states( with_prefill, is_torchair_compile, input_ids, positions, attn_metadata, num_tokens, intermediate_tensors, inputs_embeds) if need_dummy_logits: dummy_compute_logits(hidden_states) if self.drafter: self.drafter.dummy_run( num_tokens=num_tokens, with_prefill=with_prefill, skip_attn=True, num_reqs=num_reqs, num_tokens_across_dp=num_tokens_across_dp) if need_dummy_logits: dummy_compute_logits(hidden_states) if self.in_profile_run and self.dynamic_eplb: self.model.clear_all_moe_loads() if not self.in_profile_run and self.dynamic_eplb: self.eplb_updator.take_update_info_from_eplb_process() self.eplb_updator.forward_end() return hidden_states @contextmanager def set_in_profile_run(self): self.in_profile_run = True try: yield finally: self.in_profile_run = False def profile_run(self) -> None: # Trigger compilation for general shape. with self.set_in_profile_run(): hidden_states = self._dummy_run(self.max_num_tokens, with_prefill=True) # MC2 will consume additional NPU memory. # Therefore, we need to run the MC2 path once here to complete its initialization, # allowing vLLM to correctly estimate the maximum memory required. if self.max_num_tokens > self.mc2_tokens_capacity and \ self._select_moe_comm_method( self.mc2_tokens_capacity, with_prefill=True) == MoECommType.MC2: self._dummy_run(self.mc2_tokens_capacity, with_prefill=True) output = None if get_pp_group().is_last_rank: if self.is_pooling_model: output = self._dummy_pooler_run(hidden_states) else: # For profile, have maximum num_reqs and that collectively have # maximum num_tokens. min_tokens_per_req = self.max_num_tokens // self.max_num_reqs num_scheduled_tokens_list = [min_tokens_per_req ] * self.max_num_reqs num_scheduled_tokens_list[ -1] += self.max_num_tokens % self.max_num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) logit_indices = np.cumsum(num_scheduled_tokens) - 1 # TODO: need to rum a dummy sampler for generate task hidden_states = hidden_states[logit_indices] output = self.model.compute_logits(hidden_states) NPUPlatform.synchronize() del hidden_states, output self.encoder_cache.clear() gc.collect() def _dummy_pooler_run_task( self, hidden_states: torch.Tensor, task: PoolingTask, ) -> PoolerOutput: num_tokens = hidden_states.shape[0] max_num_reqs = self.scheduler_config.max_num_seqs num_reqs = min(num_tokens, max_num_reqs) min_tokens_per_req = num_tokens // num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list[-1] += num_tokens % num_reqs assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs req_num_tokens = num_tokens // num_reqs dummy_token_ids = torch.zeros((num_reqs, req_num_tokens), dtype=torch.int32, device=self.device) model = cast(VllmModelForPooling, self.get_model()) dummy_pooling_params = PoolingParams(task=task) to_update = model.pooler.get_pooling_updates(task) to_update.apply(dummy_pooling_params) dummy_prompt_lens = torch.tensor( num_scheduled_tokens_list, device="cpu", ) dummy_metadata = PoolingMetadata( prompt_lens=dummy_prompt_lens, prompt_token_ids=dummy_token_ids, pooling_params=[dummy_pooling_params] * num_reqs, ) dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, device=hidden_states.device) try: return model.pooler(hidden_states=hidden_states, pooling_metadata=dummy_metadata) except RuntimeError as e: if 'out of memory' in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up pooler " f"({task=}) with {num_reqs} dummy requests. Please try " "lowering `max_num_seqs` or `gpu_memory_utilization` when " "initializing the engine.") from e else: raise e @torch.inference_mode() def _dummy_pooler_run( self, hidden_states: torch.Tensor, ) -> PoolerOutput: # Find the task that has the largest output for subsequent steps output_size = dict[PoolingTask, float]() for task in self.get_supported_pooling_tasks(): # Run a full batch with each task to ensure none of them OOMs output = self._dummy_pooler_run_task(hidden_states, task) output_size[task] = sum(o.nbytes for o in output) del output # Allow GC max_task = max(output_size.items(), key=lambda x: x[1])[0] return self._dummy_pooler_run_task(hidden_states, max_task) def eplb_warmup(self): if self.dynamic_eplb and not self.is_eplb_warmuped: self.is_eplb_warmuped = True self.eplb_adaptor = VllmEplbAdaptor(model=self.model) self.eplb_loader.set_adator(self.eplb_adaptor) self.eplb_updator.set_adaptor(self.eplb_adaptor) self.eplb_updator.warm_up_eplb() def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 self.model = get_model(vllm_config=self.vllm_config) if self.dynamic_eplb: model_register(self.model, self.model_config) if is_310p(): from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) for module in self.model.modules(): if isinstance(module, (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)): module.weight.data = self._convert_torch_format( module.weight.data) if self.drafter: logger.info("Loading drafter model...") self.drafter.load_model(self.model) if self.drafter.name == SpecDcodeType.EAGLE3: self.model.set_aux_hidden_state_layers( self.model.get_eagle3_aux_hidden_state_layers()) if self.lora_config: self.model = self.load_lora_model(self.model, self.vllm_config, self.device) logger.info("Loading model weights took %.4f GB", m.consumed_memory / float(2**30)) # wrap the model with full graph wrapper if needed. if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): self.update_stream = torch.npu.Stream() set_graph_params(self.compilation_config.cudagraph_capture_sizes) self.model = ACLGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) def _convert_torch_format(self, tensor): tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT) return tensor def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. Args: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config self.initialize_attn_backend(kv_cache_config) self.use_hybrid_blocks = (len(self.attn_groups) > 1) # NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`. self.need_accepted_tokens = any([ isinstance(attn_group[0].kv_cache_spec, MambaSpec) for attn_group in self.attn_groups ]) self.may_reinitialize_input_batch(kv_cache_config) if self.ascend_config.is_deepseek_sfa: kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa( kv_cache_config) elif self.model_config.is_deepseek_mla: kv_caches = self.initialize_kv_cache_tensors_deepseek_mla( kv_cache_config) else: kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) def _align_memory(self, tensor: torch.Tensor, alignment: int) -> torch.Tensor: data_ptr = tensor.data_ptr() aligned_addr = (data_ptr + alignment - 1) // alignment * alignment offset = (aligned_addr - data_ptr) // tensor.element_size() return tensor[int(offset):] def initialize_kv_cache_tensors_deepseek_sfa( self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: kv_cache_sizes = {} for kv_cache_tensor in kv_cache_config.kv_cache_tensors: assert len(kv_cache_tensor.shared_by) == 1, ( "KV cache tensor shared by multiple layers is not supported in " "NPU.") kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size kv_caches: Dict[str, torch.Tensor] = {} for group in self._kv_cache_spec_attn_group_iterator(): kv_cache_spec = group.kv_cache_spec attn_backend = group.backend for layer_name in group.layer_names: if layer_name in self.runner_only_attn_layers: continue tensor_size = kv_cache_sizes[layer_name] num_blocks = tensor_size // kv_cache_spec.page_size_bytes if self.vllm_config.additional_config.get( "kv_cache_dtype", None) == 'int8': kv_cache_shape = attn_backend.get_bsh_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) elif hasattr( attn_backend, "get_supported_block_size" ) and not self.model_config.is_deepseek_mla and not self.ascend_config.is_deepseek_sfa: block_size = attn_backend.get_supported_block_size()[0] block_size_chunk = kv_cache_spec.block_size // block_size kv_cache_shape = attn_backend.get_kv_cache_shape( num_blocks * block_size_chunk, block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) else: kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype alignment = 2 * 1024 * 1024 num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape rope_dim = self.model_config.hf_text_config.qk_rope_head_dim nope_dim = head_size - rope_dim nope_cache_shape = (num_blocks, block_size, num_kv_heads, nope_dim) rope_cache_shape = (num_blocks, block_size, num_kv_heads, rope_dim) #### k cache # TODO(zzzzwwjj): wait transformers add these params k_cache_shape = (num_blocks, block_size, 1, 128) if self.vllm_config.kv_transfer_config is None: # For no disaggregate pd scenario, allocate kv cache in normal way rope_cache = torch.zeros(rope_cache_shape, dtype=dtype, device=self.device) nope_cache = torch.zeros(nope_cache_shape, dtype=dtype, device=self.device) rope_cache = self._convert_torch_format(rope_cache) nope_cache = self._convert_torch_format(nope_cache) #### k cache k_cache = torch.zeros(k_cache_shape, dtype=dtype, device=self.device) k_cache = self._convert_torch_format(k_cache) else: # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory # address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but # we found there are also some exceptions during test, so we manual align those memory here, this part # of code may consume 2M * 2 * elem_size memory every layer. nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim nope_allocate_shape_alignment = nope_allocate_shape + alignment rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim rope_allocate_shape_alignment = rope_allocate_shape + alignment nope_cache = torch.zeros(nope_allocate_shape_alignment, dtype=dtype, device=self.device) rope_cache = torch.zeros(rope_allocate_shape_alignment, dtype=dtype, device=self.device) #### k cache # TODO(zzzzwwjj): wait transformers add these params k_allocate_shape = num_blocks * block_size * 1 * 128 k_allocate_shape_alignment = k_allocate_shape + alignment k_cache = torch.zeros(k_allocate_shape_alignment, dtype=dtype, device=self.device) nope_cache = self._align_memory( nope_cache, alignment)[:nope_allocate_shape].view(nope_cache_shape) rope_cache = self._align_memory( rope_cache, alignment)[:rope_allocate_shape].view(rope_cache_shape) k_cache = self._align_memory( k_cache, alignment)[:k_allocate_shape].view(k_cache_shape) kv_caches[layer_name] = (nope_cache, rope_cache, k_cache) bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, self.kv_caches) return kv_caches def initialize_kv_cache_tensors_deepseek_mla( self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: kv_cache_sizes = {} for kv_cache_tensor in kv_cache_config.kv_cache_tensors: assert len(kv_cache_tensor.shared_by) == 1, ( "KV cache tensor shared by multiple layers is not supported in " "NPU.") kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size kv_caches: Dict[str, torch.Tensor] = {} for group in self._kv_cache_spec_attn_group_iterator(): kv_cache_spec = group.kv_cache_spec attn_backend = group.backend for layer_name in group.layer_names: if layer_name in self.runner_only_attn_layers: continue tensor_size = kv_cache_sizes[layer_name] num_blocks = tensor_size // kv_cache_spec.page_size_bytes if self.vllm_config.additional_config.get( "kv_cache_dtype", None) == 'int8': kv_cache_shape = attn_backend.get_bsh_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) elif hasattr(attn_backend, "get_supported_block_size" ) and not self.model_config.is_deepseek_mla: block_size = attn_backend.get_supported_block_size()[0] block_size_chunk = kv_cache_spec.block_size // block_size kv_cache_shape = attn_backend.get_kv_cache_shape( num_blocks * block_size_chunk, block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) else: kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype alignment = 2 * 1024 * 1024 num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape rope_dim = self.model_config.hf_text_config.qk_rope_head_dim nope_dim = head_size - rope_dim nope_cache_shape = (num_blocks, block_size, num_kv_heads, nope_dim) rope_cache_shape = (num_blocks, block_size, num_kv_heads, rope_dim) if self.vllm_config.kv_transfer_config is None: # For no disaggregate pd scenario, allocate kv cache in normal way rope_cache = torch.zeros(rope_cache_shape, dtype=dtype, device=self.device) nope_cache = torch.zeros(nope_cache_shape, dtype=dtype, device=self.device) rope_cache = self._convert_torch_format(rope_cache) nope_cache = self._convert_torch_format(nope_cache) else: # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory # address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but # we found there are also some exceptions during test, so we manual align those memory here, this part # of code may consume 2M * 2 * elem_size memory every layer. nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim nope_allocate_shape_alignment = nope_allocate_shape + alignment rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim rope_allocate_shape_alignment = rope_allocate_shape + alignment nope_cache = torch.zeros(nope_allocate_shape_alignment, dtype=dtype, device=self.device) rope_cache = torch.zeros(rope_allocate_shape_alignment, dtype=dtype, device=self.device) nope_cache = self._align_memory( nope_cache, alignment)[:nope_allocate_shape].view(nope_cache_shape) rope_cache = self._align_memory( rope_cache, alignment)[:rope_allocate_shape].view(rope_cache_shape) kv_caches[layer_name] = (nope_cache, rope_cache) bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, self.kv_caches) return kv_caches def initialize_kv_cache_tensors( self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: """ Initialize the memory buffer for KV cache. Args: kv_cache_config: The KV cache config Returns: Dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. """ # init kv cache tensors kv_cache_raw_tensors: dict[str, Union[torch.Tensor, Optional[torch.Tensor]]] = {} # llmdatadist need the addr of cache tensor be aligned with 2M alignment = 2 * 1024 * 1024 for kv_cache_tensor in kv_cache_config.kv_cache_tensors: # TODO: REFACTOR ME to sharing hybrid cache for idx in range(len(kv_cache_tensor.shared_by)): layer_name = kv_cache_tensor.shared_by[idx] if "linear_attn" in layer_name: # for mamba linear attention for layer_name_inner in kv_cache_tensor.shared_by: if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \ layer_name_inner in kv_cache_raw_tensors.keys(): continue if self.vllm_config.kv_transfer_config is None: tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=self.device) else: cache_size_aligned = kv_cache_tensor.size + alignment tensor = torch.zeros(cache_size_aligned, dtype=torch.int8, device=self.device) tensor = self._align_memory( tensor, alignment)[:kv_cache_tensor.size] kv_cache_raw_tensors[layer_name_inner] = tensor elif "attn" in layer_name: # for other attentions, e.g., self_attn, sliding window attn if self.vllm_config.kv_transfer_config is None: k_tensor = torch.zeros(kv_cache_tensor.size // 2, dtype=torch.int8, device=self.device) v_tensor = torch.zeros(kv_cache_tensor.size // 2, dtype=torch.int8, device=self.device) else: cache_size = kv_cache_tensor.size // 2 cache_size_aligned = kv_cache_tensor.size // 2 + alignment k_tensor = torch.zeros(cache_size_aligned, dtype=torch.int8, device=self.device) v_tensor = torch.zeros(cache_size_aligned, dtype=torch.int8, device=self.device) k_tensor = self._align_memory(k_tensor, alignment)[:cache_size] v_tensor = self._align_memory(v_tensor, alignment)[:cache_size] kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor) layer_names = set() for group in kv_cache_config.kv_cache_groups: for layer_name in group.layer_names: if layer_name in self.runner_only_attn_layers: continue layer_names.add(layer_name) assert layer_names == set(kv_cache_raw_tensors.keys( )), "Some layers are not correctly initialized" kv_caches: Dict[str, torch.Tensor] = {} for group in self._kv_cache_spec_attn_group_iterator(): kv_cache_spec = group.kv_cache_spec attn_backend = group.backend for layer_name in group.layer_names: if layer_name in self.runner_only_attn_layers: continue # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may # encounter OOM issue if isinstance(kv_cache_spec, FullAttentionSpec): raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[ # type: ignore layer_name] assert raw_k_tensor is not None assert raw_v_tensor is not None assert (raw_k_tensor.numel() + raw_v_tensor.numel() ) % kv_cache_spec.page_size_bytes == 0 num_blocks = (raw_k_tensor.numel() + raw_v_tensor.numel() ) // kv_cache_spec.page_size_bytes # `num_blocks` is the number of blocks the model runner can use. # `kv_cache_config.num_blocks` is the number of blocks that # KVCacheManager may allocate. # Since different GPUs may have different number of layers and # different memory capacities, `num_blocks` can be different on # different GPUs, and `kv_cache_config.num_blocks` is set to # the min of all `num_blocks`. Verify it here. assert num_blocks >= kv_cache_config.num_blocks if self.vllm_config.additional_config.get( "kv_cache_dtype", None) == 'int8': kv_cache_shape = attn_backend.get_bsh_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) elif hasattr(attn_backend, "get_supported_block_size" ) and self.use_hybrid_blocks: block_size = attn_backend.get_supported_block_size()[0] block_size_chunk = kv_cache_spec.block_size // block_size kv_cache_shape = attn_backend.get_kv_cache_shape( num_blocks * block_size_chunk, block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) else: kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype k_cache = raw_k_tensor.view(dtype).view(kv_cache_shape[1:]) k_cache = self._convert_torch_format(k_cache) v_cache = raw_v_tensor.view(dtype).view(kv_cache_shape[1:]) v_cache = self._convert_torch_format(v_cache) kv_caches[layer_name] = (k_cache, v_cache) elif isinstance(kv_cache_spec, MambaSpec): raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor is not None assert raw_tensor.numel( ) % kv_cache_spec.page_size_bytes == 0 num_blocks = raw_tensor.numel( ) // kv_cache_spec.page_size_bytes # `num_blocks` is the number of blocks the model runner can use. # `kv_cache_config.num_blocks` is the number of blocks that # KVCacheManager may allocate. # Since different GPUs may have different number of layers and # different memory capacities, `num_blocks` can be different on # different GPUs, and `kv_cache_config.num_blocks` is set to # the min of all `num_blocks`. Verify it here. assert num_blocks >= kv_cache_config.num_blocks state_tensors = [] storage_offset_bytes = 0 for (shape, dtype) in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes): dtype_size = get_dtype_size(dtype) num_element_per_page = ( kv_cache_spec.page_size_bytes // dtype_size) target_shape = (num_blocks, *shape) stride = torch.empty(target_shape).stride() target_stride = (num_element_per_page, *stride[1:]) assert storage_offset_bytes % dtype_size == 0 tensor = torch.as_strided( raw_tensor.view(dtype), size=target_shape, stride=target_stride, storage_offset=storage_offset_bytes // dtype_size, ) state_tensors.append(tensor) storage_offset_bytes += stride[0] * dtype_size kv_caches[layer_name] = state_tensors else: raise ValueError("Unknown KV cache spec type.") bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, self.kv_caches) return kv_caches def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: """ Re-initialize the input batch if the block sizes are different from `[self.cache_config.block_size]`. This usually happens when there are multiple KV cache groups. Args: kv_cache_config: The KV cache configuration. """ block_sizes = [ kv_cache_group.kv_cache_spec.block_size for kv_cache_group in kv_cache_config.kv_cache_groups ] # Generate kernel_block_sizes that matches each block_size # For attention backends that support virtual block splitting, # use the supported block sizes from the backend # For other backends (like Mamba), use [0] (no splitting) kernel_block_sizes = [] for kv_cache_group_id, kv_cache_group in enumerate( kv_cache_config.kv_cache_groups): if isinstance(kv_cache_group.kv_cache_spec, AttentionSpec): # This is an attention backend that supports virtual # block splitting. Get the supported block sizes from # the backend. try: attn_groups = self.attn_groups[kv_cache_group_id] except IndexError: attn_groups = None if attn_groups and self.use_hybrid_blocks: # Use the backend's supported block size list backend = attn_groups[0].backend supported_sizes = backend.get_supported_block_size() # If no specific sizes supported, use cache config # block_size kernel_block_size_list = (supported_sizes if supported_sizes else [self.cache_config.block_size]) else: # Fallback to cache config block_size if no backend found kernel_block_size_list = [self.cache_config.block_size] kernel_block_sizes.append(kernel_block_size_list) else: # This is likely Mamba or other non-attention cache, # no splitting. # NOTE: set kernel_block_sizes to 0 to disable slotmapping computation # of mamba block. In this case, BlockTable.block_size will never equal # to kernel_block_sizes[0] kernel_block_sizes.append([0]) if kernel_block_sizes != [[self.cache_config.block_size]]: assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 "for more details.") self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=self.model_config.max_model_len, max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=block_sizes, is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=self.input_batch.logitsprocs, is_pooling_model=self.is_pooling_model, num_speculative_tokens=( self.vllm_config.speculative_config.num_speculative_tokens if self.vllm_config.speculative_config else 0), kernel_block_sizes=kernel_block_sizes, ) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. """ assert len(self.attn_groups) == 0, \ "Attention backends are already initialized" class AttentionGroupKey(NamedTuple): attn_backend: type[AttentionBackend] kv_cache_spec: KVCacheSpec def get_attn_backends_for_group( kv_cache_group_spec: KVCacheGroupSpec, ) -> dict[AttentionGroupKey, list[str]]: layers = get_layers_from_vllm_config( self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names) attn_backends = {} attn_backend_layers = defaultdict(list) # Dedupe based on full class name; this is a bit safer than # using the class itself as the key because when we create dynamic # attention backend subclasses (e.g. ChunkedLocalAttention) unless # they are cached correctly, there will be different objects per # layer. for layer_name in kv_cache_group_spec.layer_names: attn_backend = layers[layer_name].get_attn_backend() full_cls_name = attn_backend.full_cls_name() layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[ layer_name] key = (full_cls_name, layer_kv_cache_spec) attn_backends[key] = AttentionGroupKey(attn_backend, layer_kv_cache_spec) attn_backend_layers[key].append(layer_name) return { attn_backends[k]: v for k, v in attn_backend_layers.items() } def create_attn_groups( attn_backends_map: dict[AttentionBackend, list[str]], ) -> list[AttentionGroup]: attn_groups: list[AttentionGroup] = [] for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items(): attn_metadata_builders = [] attn_metadata_builders.append(attn_backend.get_builder_cls()( kv_cache_spec, layer_names, self.vllm_config, self.device, )) attn_group = AttentionGroup(attn_backend, attn_metadata_builders, layer_names, kv_cache_spec) attn_groups.append(attn_group) return attn_groups for kv_cache_group_spec in kv_cache_config.kv_cache_groups: attn_backends = get_attn_backends_for_group( # type: ignore kv_cache_group_spec) self.attn_groups.append(create_attn_groups(attn_backends)) # Calculate reorder batch threshold (if needed) self.calculate_reorder_batch_threshold() def _attn_group_iterator(self) -> Iterator[AttentionGroup]: return itertools.chain.from_iterable(self.attn_groups) def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]: if not self.kv_cache_config.kv_cache_groups: return for attn_groups in self.attn_groups: yield from attn_groups def calculate_reorder_batch_threshold(self) -> None: """ Check that if any backends reorder batches; that the reordering is compatible (e.g., decode threshold is the same) """ for group in self._attn_group_iterator(): attn_metadata_builder_i = group.get_metadata_builder() if hasattr(attn_metadata_builder_i, "reorder_batch_threshold"): # check that if any backends reorder batches; that the reordering # is compatible (e.g., decode threshold is the same) reorder_batch_threshold_i = ( attn_metadata_builder_i.reorder_batch_threshold) if reorder_batch_threshold_i is not None: if self.reorder_batch_threshold is not None: if reorder_batch_threshold_i != \ self.reorder_batch_threshold: raise ValueError( f"Attention backend reorders decodes with " f"threshold {reorder_batch_threshold_i} but other " f"backend uses threshold " f"{self.reorder_batch_threshold}") else: self.reorder_batch_threshold = reorder_batch_threshold_i def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each Attention module in the static forward context. Returns: KVCacheSpec: A dictionary mapping layer names to their KV cache format. Layers that do not need KV cache are not included. """ block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla use_sfa = self.ascend_config.use_sfa kv_cache_spec: dict[str, KVCacheSpec] = {} attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name, attn_module in attn_layers.items(): if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: # The layer doesn't need its own KV cache and will use that of # the target layer. We skip creating a KVCacheSpec for it, so # that KV cache management logic will act as this layer does # not exist, and doesn't allocate KV cache for the layer. This # enables the memory saving of cross-layer kv sharing, allowing # a given amount of memory to accommodate longer context lengths # or enable more requests to be processed simultaneously. self.shared_kv_cache_layers[layer_name] = kv_tgt_layer continue if isinstance(attn_module, AscendMultiHeadLatentAttention): continue # TODO: Support other attention modules, e.g., cross-attention # TODO(lucas): move the attention specs into the model layers like # the attention backends if attn_module.attn_type == AttentionType.DECODER: kv_cache_spec[layer_name] = FullAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, use_mla=use_mla, use_sfa=use_sfa) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. continue elif attn_module.attn_type == AttentionType.ENCODER_DECODER: raise NotImplementedError else: raise ValueError( f"Unknown attention type: {attn_module.attn_type}") mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) if len(mamba_layers) > 0: if (self.vllm_config.speculative_config is not None and self.vllm_config.model_config.hf_config.model_type not in ["qwen3_next"]): raise NotImplementedError( "Mamba with speculative decoding is not supported yet.") if self.vllm_config.cache_config.enable_prefix_caching: raise NotImplementedError( "Prefix caching is not supported for Mamba yet.") max_model_len = self.vllm_config.model_config.max_model_len page_size_padded = ( self.vllm_config.cache_config.mamba_page_size_padded) # Set block_size to max_model_len, so that mamba model will always # have only one block in the KV cache. for layer_name, mamba_module in mamba_layers.items(): kv_cache_spec[layer_name] = MambaSpec( shapes=mamba_module.get_state_shape(), dtypes=mamba_module.get_state_dtype(), block_size=max_model_len, page_size_padded=page_size_padded, mamba_type=mamba_module.mamba_type, num_speculative_blocks=( self.speculative_config.num_speculative_tokens if self.speculative_config else 0), ) return kv_cache_spec def initialize_aclgraph_capture(self) -> None: min_ag_support = AttentionCGSupport.ALWAYS min_ag_builder_name = None for attn_group in self._attn_group_iterator(): builder = attn_group.get_metadata_builder() if builder.aclgraph_support.value < min_ag_support.value: min_ag_support = builder.aclgraph_support min_ag_builder_name = builder.__class__.__name__ # This is an imitation of compilation_config.splitting_ops_contain_attention() splitting_ops_contain_attention = ( self.compilation_config.splitting_ops is not None and all(op in self.compilation_config.splitting_ops for op in [ "vllm.unified_ascend_attention_with_output", "vllm.mla_forward", ])) # Flexible resolve the aclgraph mode aclgraph_mode = self.compilation_config.cudagraph_mode # check graph for mixed batch is supported if aclgraph_mode.mixed_mode() == CUDAGraphMode.FULL \ and min_ag_support != AttentionCGSupport.ALWAYS: msg = (f"ACLGraphMode.{aclgraph_mode.name} is not supported " f"with {min_ag_builder_name} backend (support: " f"{min_ag_support})") if min_ag_support == AttentionCGSupport.NEVER: # if not supported any full graphs, just raise it. msg += "; please try cudagraph_mode=PIECEWISE, and "\ "make sure compilation level is piecewise" raise ValueError(msg) # attempt to resolve the full graph related mode if splitting_ops_contain_attention: msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE" aclgraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.FULL_AND_PIECEWISE) else: msg += "; setting cudagraph_mode=FULL_DECODE_ONLY" aclgraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.FULL_DECODE_ONLY) logger.warning(msg) # check that if spec-decode + decode full-graphs is supported if (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and self.uniform_decode_query_len > 1 and min_ag_support.value < AttentionCGSupport.UNIFORM_BATCH.value): msg = (f"CUDAGraphMode.{aclgraph_mode.name} is not supported" f" with spec-decode for attention backend " f"{min_ag_builder_name} (support: {min_ag_support})") if splitting_ops_contain_attention: msg += "; setting cudagraph_mode=PIECEWISE" aclgraph_mode = self.compilation_config.cudagraph_mode = \ CUDAGraphMode.PIECEWISE else: msg += "; setting cudagraph_mode=NONE" aclgraph_mode = self.compilation_config.cudagraph_mode = \ CUDAGraphMode.NONE logger.warning(msg) # double check that we can support full graph if they are requested # even after automatic downgrades if aclgraph_mode.has_full_cudagraphs() \ and min_ag_support == AttentionCGSupport.NEVER: raise ValueError(f"CUDAGraphMode.{aclgraph_mode.name} is not " f"supported with {min_ag_builder_name} backend (" f"support:{min_ag_support}) " "; please try cudagraph_mode=PIECEWISE, " "and make sure compilation level is piecewise") self.aclgraph_dispatcher.initialize_cudagraph_keys( self.compilation_config.cudagraph_mode, self.uniform_decode_query_len) def _capture_aclgraphs(self, compilation_cases: list[int], aclgraph_runtime_mode: CUDAGraphMode, uniform_decode: bool): assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \ aclgraph_runtime_mode in [CUDAGraphMode.FULL, CUDAGraphMode.PIECEWISE] # Only rank 0 should print progress bar during capture if is_global_first_rank(): logger.info( "Starting to capture ACL graphs for cases: %s, " "mode: %s, uniform_decode: %s", compilation_cases, aclgraph_runtime_mode.name, uniform_decode) compilation_cases = tqdm( compilation_cases, disable=not self.load_config.use_tqdm_on_load, desc="Capturing ACL graphs ({}, {})".format( "decode" if uniform_decode else "mixed prefill-decode", aclgraph_runtime_mode.name)) # We skip EPLB here since we don't want to record dummy metrics for num_tokens in compilation_cases: for _ in range(self.compilation_config.cudagraph_num_of_warmups): # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. # But be careful, warm up with `NONE`is orthogonal to # if we want to warm up attention or not. This is # different from the case where `FULL` implies capture # attention while `PIECEWISE` implies no attention. force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL) self._dummy_run(num_tokens, aclgraph_runtime_mode=CUDAGraphMode.NONE, force_attention=force_attention, uniform_decode=uniform_decode) self._dummy_run(num_tokens, aclgraph_runtime_mode=aclgraph_runtime_mode, force_attention=force_attention, uniform_decode=uniform_decode) def _capture_model(self): if not self.use_aclgraph: logger.warning( "Skipping ACL graph capture. To turn on ACL graph capture, " "ensure `aclraph_mode` was not manually set to `NONE`") return else: self.initialize_aclgraph_capture() set_cudagraph_capturing_enabled(True) # Trigger ACL graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. with graph_capture(device=self.device): aclgraph_mode = self.compilation_config.cudagraph_mode if aclgraph_mode.mixed_mode() != CUDAGraphMode.NONE: aclgraph_runtime_mode = aclgraph_mode.mixed_mode() compilation_cases = list(reversed(self.aclgraph_batch_sizes)) try: self._capture_aclgraphs( compilation_cases, aclgraph_runtime_mode=aclgraph_runtime_mode, uniform_decode=False) except Exception as e: logger.error( f"ACLgraph sizes capture fail: {type(e).__name__}:\n" "ACLgraph has insufficient available streams to capture the configured number of sizes. " "Please verify both the availability of adequate streams and the appropriateness of the configured size count.\n\n" "Recommended solutions:\n" "1. Manually configure the compilation_config parameter " "with a reduced set of sizes: '{\"cudagraph_capture_sizes\":[size1, size2, size3, ...]}'.\n" "2. Utilize ACLgraph's full graph mode as an alternative to the piece-wise approach.\n\n" f"{str(e)}") raise if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \ aclgraph_mode.separate_routine(): max_num_tokens = self.scheduler_config.max_num_seqs * \ self.uniform_decode_query_len decode_cudagraph_batch_sizes = [ x for x in self.aclgraph_batch_sizes if x <= max_num_tokens and x >= self.uniform_decode_query_len ] compilation_cases_decode = list( reversed(decode_cudagraph_batch_sizes)) self._capture_aclgraphs( compilation_cases=compilation_cases_decode, aclgraph_runtime_mode=CUDAGraphMode.FULL, uniform_decode=True) # Disable aclgraph capturing globally, so any unexpected aclgraph # capturing will be detected and raise an error after here. # Note: We don't put it into graph_capture context manager because # we may doing lazy capturing in future that still allows capturing # after here. set_cudagraph_capturing_enabled(False) def capture_model(self) -> None: compilation_counter.num_gpu_runner_capture_triggers += 1 start_time = time.perf_counter() start_free_npu_memory = torch.npu.mem_get_info()[0] self._capture_model() end_time = time.perf_counter() end_free_npu_memory = torch.npu.mem_get_info()[0] elapsed_time = end_time - start_time npu_graph_size = start_free_npu_memory - end_free_npu_memory # This usually takes 5~20 seconds. logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, npu_graph_size / (1 << 30)) def _get_prompt_logprobs_dict( self, hidden_states: torch.Tensor, scheduler_output: "SchedulerOutput", ) -> dict[str, Optional[LogprobsTensors]]: num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs if not num_prompt_logprobs_dict: return {} in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} # Since prompt logprobs are a rare feature, prioritize simple, # maintainable loop over optimal performance. completed_prefill_reqs = [] for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items(): num_tokens = scheduler_output.num_scheduled_tokens[req_id] # Get metadata for this request. request = self.requests[req_id] num_prompt_tokens = len(request.prompt_token_ids) prompt_token_ids = torch.tensor(request.prompt_token_ids).to( self.device, non_blocking=True) # Set up target LogprobsTensors object. logprobs_tensors = in_progress_dict.get(req_id) if not logprobs_tensors: # Create empty logprobs CPU tensors for the entire prompt. # If chunked, we'll copy in slice by slice. logprobs_tensors = LogprobsTensors.empty_cpu( num_prompt_tokens - 1, num_prompt_logprobs + 1) in_progress_dict[req_id] = logprobs_tensors # Determine number of logits to retrieve. start_idx = request.num_computed_tokens start_tok = start_idx + 1 num_remaining_tokens = num_prompt_tokens - start_tok if num_tokens <= num_remaining_tokens: # This is a chunk, more tokens remain. # In the == case, there are no more prompt logprobs to produce # but we want to defer returning them to the next step where we # have new generated tokens to return. num_logits = num_tokens else: # This is the last chunk of prompt tokens to return. num_logits = num_remaining_tokens completed_prefill_reqs.append(req_id) prompt_logprobs_dict[req_id] = logprobs_tensors if num_logits <= 0: # This can happen for the final chunk if we prefilled exactly # (num_prompt_tokens - 1) tokens for this request in the prior # step. There are no more prompt logprobs to produce. continue # Get the logits corresponding to this req's prompt tokens. # If this is a partial request (i.e. chunked prefill), # then there is prompt logprob generated for each index. req_idx = self.input_batch.req_id_to_index[req_id] offset = self.query_start_loc_np[req_idx].item() prompt_hidden_states = hidden_states[offset:offset + num_logits] logits = self.model.compute_logits(prompt_hidden_states) # Get the "target" tokens for each index. For prompt at index i, # the token at prompt index i+1 is the "sampled" token we want # to gather the logprob for. tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits] # Compute prompt logprobs. logprobs = self.sampler.compute_logprobs(logits) token_ids, logprobs, ranks = self.sampler.gather_logprobs( logprobs, num_prompt_logprobs, tgt_token_ids) # Transfer NPU->CPU async. chunk_slice = slice(start_idx, start_idx + num_logits) logprobs_tensors.logprob_token_ids[chunk_slice].copy_( token_ids, non_blocking=True) logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True) logprobs_tensors.selected_token_ranks[chunk_slice].copy_( ranks, non_blocking=True) # Remove requests that have completed prefill from the batch # num_prompt_logprobs_dict. for req_id in completed_prefill_reqs: del num_prompt_logprobs_dict[req_id] del in_progress_dict[req_id] # Must synchronize the non-blocking NPU->CPU transfers. if prompt_logprobs_dict: torch.npu.synchronize() return prompt_logprobs_dict def get_supported_pooling_tasks(self): model = self.get_model() if not is_pooling_model(model): return [] return list(model.pooler.get_supported_tasks()) def _build_drafter_prepare_inputs_torchair_param(self): return False