From b4233a2ec35a91beca4b8c8402ea7cf6394b62c9 Mon Sep 17 00:00:00 2001
From: Shirley125 <54166744+Shirley125@users.noreply.github.com>
Date: Sat, 18 Oct 2025 15:56:44 +0800
Subject: [PATCH] [Bugfix] Route requests requiring KVC recomputation from the
 decode instance to the P instance (#3448)

### What this PR does / why we need it?
This PR fixes the recompute out-of-memory bug in the decode instance. When
recomputation happens in decode, KV cache usage may exceed the pre-allocated
memory and cause OOM. We therefore propose a new scheduling strategy: when the
decode instance cannot allocate a new block for a running request, we stop the
request that would otherwise be preempted. These stopped requests are
recognized by the proxy and sent to the prefill instance again to recompute
their KV cache, then routed back to the decode instance. This is a temporary
fix; the long-term strategy is to use CPU offload in the decode instance.

### Does this PR introduce _any_ user-facing change?
An extra Ascend configuration option, **recompute_scheduler_enable = True**,
is added to enable this strategy. The default value is False.

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: CHEN <116010019@link.cuhk.edu.cn>
---
 ..._balance_proxy_layerwise_server_example.py |  228 ++-
 .../load_balance_proxy_server_example.py      |  208 ++-
 vllm_ascend/ascend_config.py                  |    2 +
 vllm_ascend/core/recompute_schedule_config.py |   39 +
 vllm_ascend/core/recompute_scheduler.py       | 1392 +++++++++++++++++
 vllm_ascend/platform.py                       |    6 +
 6 files changed, 1761 insertions(+), 114 deletions(-)
 create mode 100644 vllm_ascend/core/recompute_schedule_config.py
 create mode 100644 vllm_ascend/core/recompute_scheduler.py

diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
index 61d420156..1336e5a34 100644
--- a/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
+++ b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
@@ -84,17 +84,18 @@
 #
 # For more details, see the code and comments in this file.
- import argparse import asyncio import functools import heapq +import json import os import sys -import uuid import threading +import uuid from contextlib import asynccontextmanager -from typing import List +from dataclasses import dataclass +from typing import Any, List import httpx from fastapi import FastAPI, Request @@ -106,6 +107,7 @@ logger = init_logger(__name__) # Add uvloop for faster event loop if available try: import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except ImportError: pass @@ -324,7 +326,7 @@ async def listen_for_disconnect(request: Request) -> None: def with_cancellation(handler_func): - + @functools.wraps(handler_func) async def wrapper(*args, **kwargs): request = kwargs["request"] @@ -337,9 +339,9 @@ def with_cancellation(handler_func): if handler_task in done: return handler_task.result() return None - + return wrapper - + app = FastAPI(lifespan=lifespan) @@ -362,7 +364,8 @@ async def send_request_to_service(client: httpx.AsyncClient, "remote_host": None, "remote_port": None, "aborted_request": list(aborted_requests), - "metaserver": f"http://{global_args.host}:{global_args.port}/v1/metaserver" + "metaserver": + f"http://{global_args.host}:{global_args.port}/v1/metaserver" } req_data["stream"] = False req_data["max_tokens"] = 1 @@ -455,72 +458,174 @@ def get_api_request_id(api, req_id): return "chatcmpl-" + req_id +async def _handle_select_instance(api: str, req_data: Any, + request_length: int): + prefiller_score = proxy_state.calculate_prefill_scores(request_length) + logger.debug( + f"Request length: {request_length}, Prefiller score: {prefiller_score}" + ) + request_id = await proxy_state.next_req_id() + # Select prefiller + prefiller_idx = proxy_state.select_prefiller(prefiller_score) + prefiller = proxy_state.prefillers[prefiller_idx] + result_future = asyncio.Future() # type: ignore + request_id_api = get_api_request_id(api, request_id) + proxy_state.req_id_future[request_id_api] = result_future + # Send request to prefiller + asyncio.get_running_loop().create_task( + send_request_to_service(prefiller.client, + prefiller_idx, + api, + req_data, + request_id, + max_retries=global_args.max_retries, + base_delay=global_args.retry_delay)) + proxy_state.release_prefiller(prefiller_idx, prefiller_score) + + response = await result_future + del proxy_state.req_id_future[request_id_api] + req_data["kv_transfer_params"] = response + + # Select decoder + decoder_score = proxy_state.calculate_decode_scores(request_length) + logger.debug("Decoder score: %f", decoder_score) + # Use the prefiller's kv_transfer_params to select decoder + decoder_idx = proxy_state.select_decoder(decoder_score) + decoder = proxy_state.decoders[decoder_idx] + logger.debug("Using %s %s", prefiller.url, decoder.url) + return InstanceInfo(request_id=request_id, + prefiller_idx=prefiller_idx, + prefiller_score=prefiller_score, + prefiller=prefiller, + decoder=decoder, + decoder_idx=decoder_idx, + decoder_score=decoder_score) + + +@dataclass +class InstanceInfo: + request_id: str + prefiller_idx: int + prefiller_score: float + prefiller: ServerState + decoder_idx: int + decoder_score: float + decoder: ServerState + + async def _handle_completions(api: str, request: Request): try: req_data = await request.json() req_body = await request.body() request_length = len(req_body) - prefiller_score = proxy_state.calculate_prefill_scores(request_length) - logger.debug( - f"Request length: {request_length}, Prefiller score: {prefiller_score}" - ) - request_id = await 
proxy_state.next_req_id() - # Select prefiller - prefiller_idx = proxy_state.select_prefiller(prefiller_score) - prefiller = proxy_state.prefillers[prefiller_idx] - result_future = asyncio.Future() # type: ignore - request_id_api = get_api_request_id(api, request_id) - proxy_state.req_id_future[request_id_api] = result_future - # Send request to prefiller - asyncio.get_running_loop().create_task(send_request_to_service( - prefiller.client, - prefiller_idx, - api, - req_data, - request_id, - max_retries=global_args.max_retries, - base_delay=global_args.retry_delay)) - proxy_state.release_prefiller(prefiller_idx, prefiller_score) - - response = await result_future - del proxy_state.req_id_future[request_id_api] - req_data["kv_transfer_params"] = response + instance_info = await _handle_select_instance(api, req_data, + request_length) + stream_flag = bool(req_data.get("stream", False)) + chat_flag = "messages" in req_data + + if "prompt" in req_data: + origin_prompt = req_data["prompt"] + elif chat_flag: + messages = req_data["messages"] + origin_prompt = messages[0].get("content", "") + else: + origin_prompt = "" + # refer to vLLM sampling_params: max_token default value + origin_max_tokens = req_data.get("max_tokens", 16) - # Select decoder - decoder_score = proxy_state.calculate_decode_scores(request_length) - logger.debug("Decoder score: %f", decoder_score) - # Use the prefiller's kv_transfer_params to select decoder - decoder_idx = proxy_state.select_decoder(decoder_score) - decoder = proxy_state.decoders[decoder_idx] - logger.debug("Using %s %s", prefiller.url, decoder.url) - # Stream response from decoder - released_kv = False async def generate_stream(): - nonlocal released_kv + nonlocal instance_info + generated_token = "" + released_kv = False + retry_count = 0 + retry = True + completion_tokens = 0 # Only one await per chunk, minimal logic in loop try: - async for chunk in stream_service_response_with_retry( - decoder.client, - api, - req_data, - request_id=request_id, - max_retries=global_args.max_retries, - base_delay=global_args.retry_delay): - if not released_kv and chunk: - proxy_state.release_prefiller_kv( - prefiller_idx, prefiller_score) - released_kv = True - yield chunk + while retry: + retry = False + async for chunk in stream_service_response_with_retry( + instance_info.decoder.client, + api, + req_data, + request_id=instance_info.request_id, + max_retries=global_args.max_retries, + base_delay=global_args.retry_delay): + if not released_kv and chunk: + proxy_state.release_prefiller_kv( + instance_info.prefiller_idx, + instance_info.prefiller_score) + released_kv = True + chunk_str = chunk.decode("utf-8").strip() + if not chunk_str: + continue + if chunk_str.startswith("data: "): + chunk_str = chunk_str[len("data: "):] + try: + chunk_json = json.loads(chunk_str) + except json.JSONDecodeError: + # if chunk is [done], skip it. 
+ logger.warning( + f"Skipping chunk: {chunk_str}") + yield chunk + continue + choices = chunk_json.get("choices", []) + if not choices: + yield chunk + continue + + choice = choices[0] + delta = choice.get("delta") or {} + message = choice.get("message") or {} + content = ( + delta.get("content") + or message.get("content") + or choice.get("text") + or "" + ) + generated_token += content + + stop_reason = choice.get( + "stop_reason") + usage = chunk_json.get("usage", {}) + completion_tokens = (completion_tokens + 1) if stream_flag else \ + (completion_tokens + usage.get("completion_tokens")) + if stop_reason == "recomputed": + retry = True + retry_count += 1 + if chat_flag: + messages[0][ + "content"] = origin_prompt + generated_token + else: + req_data[ + "prompt"] = origin_prompt + generated_token + req_data[ + "max_tokens"] = origin_max_tokens - completion_tokens + retry_count + tmp_request_length = len( + json.dumps(req_data).encode("utf-8")) + instance_info = await _handle_select_instance( + api, req_data, tmp_request_length) + break + if retry_count > 0 and not stream_flag: + if chat_flag: + choices[0]["message"][ + "content"] = generated_token + else: + choices[0]["text"] = generated_token + chunk = json.dumps(chunk_json).encode("utf-8") + yield chunk except Exception as e: logger.error( - f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it" + f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it" ) - proxy_state.abort_prefiller_request(prefiller_idx, request_id) - proxy_state.release_prefiller_kv(prefiller_idx, - prefiller_score) + proxy_state.abort_prefiller_request( + instance_info.prefiller_idx, instance_info.request_id) + proxy_state.release_prefiller_kv(instance_info.prefiller_idx, + instance_info.prefiller_score) # After streaming done, release tokens - proxy_state.release_decoder(decoder_idx, decoder_score) + proxy_state.release_decoder(instance_info.decoder_idx, + instance_info.decoder_score) return StreamingResponse(generate_stream(), media_type="application/json") @@ -564,13 +669,12 @@ async def metaserver(request: Request): result_future = proxy_state.req_id_future[request_id] result_future.set_result(req_data) except Exception as e: - logger.error( - f"Post metaserver failed with: {str(e)}" - ) + logger.error(f"Post metaserver failed with: {str(e)}") if __name__ == '__main__': global global_args global_args = parse_args() import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py index fd1c7e593..0e28deb8d 100644 --- a/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py +++ b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py @@ -84,16 +84,17 @@ # # For more details, see the code and comments in this file. 
- import argparse import asyncio import functools import heapq +import json import os import sys import uuid from contextlib import asynccontextmanager -from typing import List +from dataclasses import dataclass +from typing import Any, List import httpx from fastapi import FastAPI, Request @@ -105,6 +106,7 @@ logger = init_logger(__name__) # Add uvloop for faster event loop if available try: import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except ImportError: pass @@ -443,69 +445,170 @@ async def stream_service_response_with_retry(client: httpx.AsyncClient, raise e +async def _handle_select_instance(api: str, req_data: Any, + request_length: int): + prefiller_score = proxy_state.calculate_prefill_scores(request_length) + logger.debug( + f"Request length: {request_length}, Prefiller score: {prefiller_score}" + ) + request_id = await proxy_state.next_req_id() + # Select prefiller + prefiller_idx = proxy_state.select_prefiller(prefiller_score) + prefiller = proxy_state.prefillers[prefiller_idx] + # Send request to prefiller + response = await send_request_to_service( + prefiller.client, + prefiller_idx, + api, + req_data, + request_id, + max_retries=global_args.max_retries, + base_delay=global_args.retry_delay) + proxy_state.release_prefiller(prefiller_idx, prefiller_score) + response_json = response.json() + kv_transfer_params = response_json.get('kv_transfer_params', {}) + if kv_transfer_params: + req_data["kv_transfer_params"] = kv_transfer_params + # Select decoder + decoder_score = proxy_state.calculate_decode_scores(request_length) + logger.debug("Decoder score: %f", decoder_score) + # Use the prefiller's kv_transfer_params to select decoder + decoder_idx = proxy_state.select_decoder(decoder_score) + decoder = proxy_state.decoders[decoder_idx] + logger.debug("Using %s %s", prefiller.url, decoder.url) + return InstanceInfo(request_id=request_id, + prefiller_idx=prefiller_idx, + prefiller_score=prefiller_score, + prefiller=prefiller, + decoder=decoder, + decoder_idx=decoder_idx, + decoder_score=decoder_score) + + +@dataclass +class InstanceInfo: + request_id: str + prefiller_idx: int + prefiller_score: float + prefiller: ServerState + decoder_idx: int + decoder_score: float + decoder: ServerState + + async def _handle_completions(api: str, request: Request): try: req_data = await request.json() req_body = await request.body() request_length = len(req_body) - prefiller_score = proxy_state.calculate_prefill_scores(request_length) - logger.debug( - f"Request length: {request_length}, Prefiller score: {prefiller_score}" - ) - request_id = await proxy_state.next_req_id() - # Select prefiller - prefiller_idx = proxy_state.select_prefiller(prefiller_score) - prefiller = proxy_state.prefillers[prefiller_idx] - # Send request to prefiller - response = await send_request_to_service( - prefiller.client, - prefiller_idx, - api, - req_data, - request_id, - max_retries=global_args.max_retries, - base_delay=global_args.retry_delay) - proxy_state.release_prefiller(prefiller_idx, prefiller_score) - response_json = response.json() - kv_transfer_params = response_json.get('kv_transfer_params', {}) - if kv_transfer_params: - req_data["kv_transfer_params"] = kv_transfer_params - # Select decoder - decoder_score = proxy_state.calculate_decode_scores(request_length) - logger.debug("Decoder score: %f", decoder_score) - # Use the prefiller's kv_transfer_params to select decoder - decoder_idx = proxy_state.select_decoder(decoder_score) - decoder = proxy_state.decoders[decoder_idx] - 
logger.debug("Using %s %s", prefiller.url, decoder.url) - # Stream response from decoder - released_kv = False + instance_info = await _handle_select_instance(api, req_data, + request_length) + stream_flag = bool(req_data.get("stream", False)) + chat_flag = "messages" in req_data + + if "prompt" in req_data: + origin_prompt = req_data["prompt"] + elif chat_flag: + messages = req_data["messages"] + origin_prompt = messages[0].get("content", "") + else: + origin_prompt = "" + # refer to vLLM sampling_params: max_token default value + origin_max_tokens = req_data.get("max_tokens", 16) async def generate_stream(): - nonlocal released_kv + nonlocal instance_info + generated_token = "" + released_kv = False + retry_count = 0 + retry = True + completion_tokens = 0 # Only one await per chunk, minimal logic in loop try: - async for chunk in stream_service_response_with_retry( - decoder.client, - api, - req_data, - request_id=request_id, - max_retries=global_args.max_retries, - base_delay=global_args.retry_delay): - if not released_kv and chunk: - proxy_state.release_prefiller_kv( - prefiller_idx, prefiller_score) - released_kv = True - yield chunk + while retry: + retry = False + async for chunk in stream_service_response_with_retry( + instance_info.decoder.client, + api, + req_data, + request_id=instance_info.request_id, + max_retries=global_args.max_retries, + base_delay=global_args.retry_delay): + if not released_kv and chunk: + proxy_state.release_prefiller_kv( + instance_info.prefiller_idx, + instance_info.prefiller_score) + released_kv = True + chunk_str = chunk.decode("utf-8").strip() + if not chunk_str: + continue + if chunk_str.startswith("data: "): + chunk_str = chunk_str[len("data: "):] + try: + chunk_json = json.loads(chunk_str) + except json.JSONDecodeError: + # if chunk is [done], skip it. 
+ logger.warning( + f"Skipping chunk: {chunk_str}") + yield chunk + continue + choices = chunk_json.get("choices", []) + if not choices: + yield chunk + continue + + choice = choices[0] + delta = choice.get("delta") or {} + message = choice.get("message") or {} + content = ( + delta.get("content") + or message.get("content") + or choice.get("text") + or "" + ) + generated_token += content + + stop_reason = choice.get( + "stop_reason") + usage = chunk_json.get("usage", {}) + completion_tokens = (completion_tokens + 1) if stream_flag else \ + (completion_tokens + usage.get("completion_tokens")) + if stop_reason == "recomputed": + retry = True + retry_count += 1 + if chat_flag: + messages[0][ + "content"] = origin_prompt + generated_token + else: + req_data[ + "prompt"] = origin_prompt + generated_token + req_data[ + "max_tokens"] = origin_max_tokens - completion_tokens + retry_count + tmp_request_length = len( + json.dumps(req_data).encode("utf-8")) + instance_info = await _handle_select_instance( + api, req_data, tmp_request_length) + break + if retry_count > 0 and not stream_flag: + if chat_flag: + choice["message"][ + "content"] = generated_token + else: + choice["text"] = generated_token + chunk = json.dumps(chunk_json).encode("utf-8") + yield chunk except Exception as e: logger.error( - f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it" + f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it" ) - proxy_state.abort_prefiller_request(prefiller_idx, request_id) - proxy_state.release_prefiller_kv(prefiller_idx, - prefiller_score) + proxy_state.abort_prefiller_request( + instance_info.prefiller_idx, instance_info.request_id) + proxy_state.release_prefiller_kv(instance_info.prefiller_idx, + instance_info.prefiller_score) # After streaming done, release tokens - proxy_state.release_decoder(decoder_idx, decoder_score) + proxy_state.release_decoder(instance_info.decoder_idx, + instance_info.decoder_score) return StreamingResponse(generate_stream(), media_type="application/json") @@ -544,4 +647,5 @@ if __name__ == '__main__': global global_args global_args = parse_args() import uvicorn - uvicorn.run(app, host=global_args.host, port=global_args.port) \ No newline at end of file + + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 6a606959f..a265e9697 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -70,6 +70,8 @@ class AscendConfig: ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel self.multistream_overlap_shared_expert = additional_config.get( "multistream_overlap_shared_expert", False) + self.recompute_scheduler_enable = additional_config.get( + "recompute_scheduler_enable", False) self.lmhead_tensor_parallel_size = additional_config.get( "lmhead_tensor_parallel_size", None) if self.lmhead_tensor_parallel_size is not None: diff --git a/vllm_ascend/core/recompute_schedule_config.py b/vllm_ascend/core/recompute_schedule_config.py new file mode 100644 index 000000000..be19a1c70 --- /dev/null +++ b/vllm_ascend/core/recompute_schedule_config.py @@ -0,0 +1,39 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +from dataclasses import dataclass, fields +from typing import Type, Union + +from vllm.config import SchedulerConfig + +MAX_INT = 2147483647 + + +@dataclass +class RecomputeSchedulerConfig(SchedulerConfig): + scheduler_cls: Union[str, Type[object]] = ( + "vllm_ascend.core.recompute_scheduler.RecomputeScheduler") + + @classmethod + def initialize_from_config(cls, vllm_scheduler_config: SchedulerConfig): + scheduler_config = { + field.name: getattr(vllm_scheduler_config, field.name) + for field in fields(vllm_scheduler_config) if field.init + } + scheduler_config["scheduler_cls"] = ( + "vllm_ascend.core.recompute_scheduler.RecomputeScheduler") + return cls(**scheduler_config) diff --git a/vllm_ascend/core/recompute_scheduler.py b/vllm_ascend/core/recompute_scheduler.py new file mode 100644 index 000000000..8946e2f2f --- /dev/null +++ b/vllm_ascend/core/recompute_scheduler.py @@ -0,0 +1,1392 @@ +## +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# +from __future__ import annotations + +import itertools +import time +from collections import defaultdict +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any, Optional, Union + +import numpy as np +import numpy.typing as npt +from vllm.config import VllmConfig +from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch +from vllm.distributed.kv_transfer.kv_connector.factory import \ + KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, + KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.v1.base import \ + KVConnectorMetadata +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import \ + KVConnectorStats +from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, + compute_encoder_budget) +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager +from vllm.v1.core.sched.interface import SchedulerInterface +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData +from vllm.v1.core.sched.request_queue import (SchedulingPolicy, + create_request_queue) +from vllm.v1.core.sched.utils import check_stop, remove_all +from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, + EngineCoreOutputs, FinishReason) +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.metrics.stats import SchedulerStats +from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput +from vllm.v1.request import Request, RequestStatus +from vllm.v1.spec_decode.metrics import SpecDecodingStats +from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.utils import ConstantList + +logger = init_logger(__name__) + + +class RecomputeScheduler(SchedulerInterface): + """This Scheduler extends vllm's original v1 scheduler of version 0.11 + to fix recomputing bug.""" + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_config: KVCacheConfig, + structured_output_manager: StructuredOutputManager, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + include_finished_set: bool = False, + log_stats: bool = False, + ) -> None: + self.vllm_config = vllm_config + self.scheduler_config = vllm_config.scheduler_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.kv_cache_config = kv_cache_config + self.kv_events_config = vllm_config.kv_events_config + self.parallel_config = vllm_config.parallel_config + self.log_stats = log_stats + self.structured_output_manager = structured_output_manager + self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder + + # include_finished_set controls whether a separate set of finished + # request ids should be included in the EngineCoreOutputs returned + # by update_from_outputs(). This is currently used in the multi-engine + # case to track request lifetimes efficiently. + self.finished_req_ids_dict: Optional[dict[int, set[str]]] = ( + defaultdict(set) if include_finished_set else None) + + # Scheduling constraints. + self.max_num_running_reqs = self.scheduler_config.max_num_seqs + self.max_num_scheduled_tokens = \ + self.scheduler_config.max_num_batched_tokens + self.max_model_len = self.scheduler_config.max_model_len + self.enable_kv_cache_events = ( + self.kv_events_config is not None + and self.kv_events_config.enable_kv_cache_events) + + # Create KVConnector for the Scheduler. 
Note that each Worker + # will have a corresponding KVConnector with Role=WORKER. + # KV Connector pushes/pull of remote KVs for P/D and offloading. + self.connector = None + if self.vllm_config.kv_transfer_config is not None: + assert len(self.kv_cache_config.kv_cache_groups) == 1, ( + "Multiple KV cache groups are not currently supported " + "with KV connectors") + assert not self.is_encoder_decoder, ( + "Encoder-decoder models are not currently supported " + "with KV connectors") + self.connector = KVConnectorFactory.create_connector( + config=self.vllm_config, role=KVConnectorRole.SCHEDULER) + + self.kv_event_publisher = EventPublisherFactory.create( + self.kv_events_config, + self.parallel_config.data_parallel_rank, + ) + + num_gpu_blocks = self.cache_config.num_gpu_blocks + assert num_gpu_blocks is not None and num_gpu_blocks > 0 + + self.block_size = self.cache_config.block_size + + self.dcp_world_size = \ + vllm_config.parallel_config.decode_context_parallel_size + # Note(hc): The scheduler’s block_size must be multiplied + # by dcp_world_size, since block hashes are computed on the + # original full token sequence at a granularity of + # original_block_size × dcp_world_size. + if self.dcp_world_size > 1: + self.block_size *= self.dcp_world_size + + # req_id -> Request + self.requests: dict[str, Request] = {} + # Scheduling policy + if self.scheduler_config.policy == "priority": + self.policy = SchedulingPolicy.PRIORITY + elif self.scheduler_config.policy == "fcfs": + self.policy = SchedulingPolicy.FCFS + else: + raise ValueError( + f"Unknown scheduling policy: {self.scheduler_config.policy}") + # Priority queues for requests. + self.waiting = create_request_queue(self.policy) + self.running: list[Request] = [] + + # The request IDs that are finished in between the previous and the + # current steps. This is used to notify the workers about the finished + # requests so that they can free the cached states for those requests. + # This is flushed at the end of each scheduling step. + self.finished_req_ids: set[str] = set() + + # KV Connector: requests in process of async KV loading or recving + self.finished_recving_kv_req_ids: set[str] = set() + + # Encoder-related. + # Calculate encoder cache size if applicable + # NOTE: For now we use the same budget for both compute and space. + # This can be changed when we make encoder cache for embedding caching + # across requests. + encoder_compute_budget, encoder_cache_size = compute_encoder_budget( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + mm_registry=mm_registry, + ) + + # NOTE(woosuk): Here, "encoder" includes the vision encoder (and + # projector if needed) for MM models as well as encoder-decoder + # transformers. + self.max_num_encoder_input_tokens = encoder_compute_budget + # NOTE: For the models without encoder (e.g., text-only models), + # the encoder cache will not be initialized because cache size is 0 + # for these models. + self.encoder_cache_manager = EncoderCacheManager( + cache_size=encoder_cache_size) + + speculative_config = vllm_config.speculative_config + self.use_eagle = False + self.num_spec_tokens = self.num_lookahead_tokens = 0 + if speculative_config: + self.num_spec_tokens = speculative_config.num_speculative_tokens + if speculative_config.use_eagle(): + self.use_eagle = True + self.num_lookahead_tokens = self.num_spec_tokens + + # Create the KV cache manager. 
+ self.kv_cache_manager = KVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=self.max_model_len, + enable_caching=self.cache_config.enable_prefix_caching, + use_eagle=self.use_eagle, + log_stats=self.log_stats, + enable_kv_cache_events=self.enable_kv_cache_events, + dcp_world_size=self.dcp_world_size, + ) + self.use_pp = self.parallel_config.pipeline_parallel_size > 1 + + def schedule(self) -> RecomputeSchedulerOutput: + """This scheduler extends vLLM's original v1 scheduler + by introducing a decoding instance recomputing scheduling strategy. + Specifically, if a request is preempted in the decoding instance, + it halts the process with the recomputed symbol and recalculates + its KVC in the prefill instance.""" + + scheduled_new_reqs: list[Request] = [] + scheduled_resumed_reqs: list[Request] = [] + scheduled_running_reqs: list[Request] = [] + preempted_reqs: list[Request] = [] + recomputed_reqs: list[RecomputeReqInfo] = [] + + req_to_new_blocks: dict[str, KVCacheBlocks] = {} + num_scheduled_tokens: dict[str, int] = {} + token_budget = self.max_num_scheduled_tokens + # Encoder-related. + scheduled_encoder_inputs: dict[str, list[int]] = {} + encoder_compute_budget = self.max_num_encoder_input_tokens + # Spec decode-related. + scheduled_spec_decode_tokens: dict[str, list[int]] = {} + + # For logging. + scheduled_timestamp = time.monotonic() + + # First, schedule the RUNNING requests. + req_index = 0 + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] + + num_new_tokens = (request.num_tokens_with_spec + + request.num_output_placeholders - + request.num_computed_tokens) + if (0 < self.scheduler_config.long_prefill_token_threshold < + num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = min(num_new_tokens, token_budget) + + # Make sure the input position does not exceed the max model len. + # This is necessary when using spec decoding. + num_new_tokens = min( + num_new_tokens, + self.max_model_len - 1 - request.num_computed_tokens) + + # Schedule encoder inputs. + encoder_inputs_to_schedule = None + new_encoder_compute_budget = encoder_compute_budget + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_compute_budget + ) = self._try_schedule_encoder_inputs( + request, request.num_computed_tokens, num_new_tokens, + encoder_compute_budget) + + if num_new_tokens == 0: + # The request cannot be scheduled because one of the following + # reasons: + # 1. No new tokens to schedule. This may happen when + # (1) PP>1 and we have already scheduled all prompt tokens + # but they are not finished yet. + # (2) Async scheduling and the request has reached to either + # its max_total_tokens or max_model_len. + # 2. The encoder budget is exhausted. + # 3. The encoder cache is exhausted. + # NOTE(woosuk): Here, by doing `continue` instead of `break`, + # we do not strictly follow the FCFS scheduling policy and + # allow the lower-priority requests to be scheduled. 
+ req_index += 1 + continue + + while True: + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_lookahead_tokens=self.num_lookahead_tokens) + if new_blocks is None: + transfer_config = self.vllm_config.kv_transfer_config + if transfer_config is not None and not transfer_config.is_kv_producer: + recomputed_req = self.running.pop() + self.kv_cache_manager.free(recomputed_req) + recomputed_reqs.append( + RecomputeReqInfo(recomputed_req.request_id, + recomputed_req.output_token_ids, + recomputed_req.client_index)) + if recomputed_req == request: + can_schedule = False + break + else: + # The request cannot be scheduled. + # Preempt the lowest-priority request. + if self.policy == SchedulingPolicy.PRIORITY: + preempted_req = max( + self.running, + key=lambda r: (r.priority, r.arrival_time), + ) + self.running.remove(preempted_req) + if preempted_req in scheduled_running_reqs: + scheduled_running_reqs.remove(preempted_req) + else: + preempted_req = self.running.pop() + + self.kv_cache_manager.free(preempted_req) + self.encoder_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, + scheduled_timestamp) + + self.waiting.prepend_request(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. + can_schedule = False + break + else: + # The request can be scheduled. + can_schedule = True + break + if not can_schedule: + break + assert new_blocks is not None + + # Schedule the request. + scheduled_running_reqs.append(request) + req_to_new_blocks[request.request_id] = new_blocks + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + req_index += 1 + + # Speculative decode related. + if request.spec_token_ids: + num_scheduled_spec_tokens = (num_new_tokens + + request.num_computed_tokens - + request.num_tokens) + if num_scheduled_spec_tokens > 0: + # Trim spec_token_ids list to num_scheduled_spec_tokens. + del request.spec_token_ids[num_scheduled_spec_tokens:] + scheduled_spec_decode_tokens[request.request_id] = ( + request.spec_token_ids) + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_compute_budget = new_encoder_compute_budget + + # Record the LoRAs in scheduled_running_reqs + scheduled_loras: set[int] = set() + if self.lora_config: + scheduled_loras = set( + req.lora_request.lora_int_id for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0) + assert len(scheduled_loras) <= self.lora_config.max_loras + + # Use a temporary RequestQueue to collect requests that need to be + # skipped and put back at the head of the waiting queue later + skipped_waiting_requests = create_request_queue(self.policy) + + # Next, schedule the WAITING requests. + if not preempted_reqs: + while self.waiting and token_budget > 0: + if len(self.running) == self.max_num_running_reqs: + break + + request = self.waiting.peek_request() + + # KVTransfer: skip request if still waiting for remote kvs. 
+ if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + is_ready = self._update_waiting_for_remote_kv(request) + if is_ready: + request.status = RequestStatus.WAITING + else: + logger.debug( + "%s is still in WAITING_FOR_REMOTE_KVS state.", + request.request_id) + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Skip request if the structured output request is still waiting + # for FSM compilation. + if request.status == RequestStatus.WAITING_FOR_FSM: + structured_output_req = request.structured_output_request + if structured_output_req and structured_output_req.grammar: + request.status = RequestStatus.WAITING + else: + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Check that adding the request still respects the max_loras + # constraint. + if (self.lora_config and request.lora_request and + (len(scheduled_loras) == self.lora_config.max_loras and + request.lora_request.lora_int_id not in scheduled_loras)): + # Scheduling would exceed max_loras, skip. + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + num_external_computed_tokens = 0 + load_kv_async = False + + # Get already-cached tokens. + if request.num_computed_tokens == 0: + # Get locally-cached tokens. + new_computed_blocks, num_new_local_computed_tokens = \ + self.kv_cache_manager.get_computed_blocks( + request) + + # Get externally-cached tokens if using a KVConnector. + if self.connector is not None: + num_external_computed_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + request, num_new_local_computed_tokens)) + + if num_external_computed_tokens is None: + # The request cannot be scheduled because + # the KVConnector couldn't determine + # the number of matched tokens. + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Total computed tokens (local + external). + num_computed_tokens = (num_new_local_computed_tokens + + num_external_computed_tokens) + # KVTransfer: WAITING reqs have num_computed_tokens > 0 + # after async KV recvs are completed. + else: + new_computed_blocks = ( + self.kv_cache_manager.create_empty_block_list()) + num_new_local_computed_tokens = 0 + num_computed_tokens = request.num_computed_tokens + + encoder_inputs_to_schedule = None + new_encoder_compute_budget = encoder_compute_budget + + # KVTransfer: loading remote KV, do not allocate for new work. + if load_kv_async: + assert num_external_computed_tokens > 0 + num_new_tokens = 0 + # Number of tokens to be scheduled. + else: + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed + # requests, which have output tokens. + num_new_tokens = request.num_tokens - num_computed_tokens + if (0 < self.scheduler_config.long_prefill_token_threshold + < num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + + # chunked prefill has to be enabled explicitly to allow + # pooling requests to be chunked + if not self.scheduler_config.chunked_prefill_enabled and \ + num_new_tokens > token_budget: + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. 
+ if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_compute_budget + ) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_compute_budget) + if num_new_tokens == 0: + # The request cannot be scheduled. + break + + # Handles an edge case when P/D Disaggregation + # is used with Spec Decoding where an + # extra block gets allocated which + # creates a mismatch between the number + # of local and remote blocks. + effective_lookahead_tokens = (0 if request.num_computed_tokens + == 0 else + self.num_lookahead_tokens) + + # Determine if we need to allocate cross-attention blocks. + if self.is_encoder_decoder and request.has_encoder_inputs: + # TODO(russellb): For Whisper, we know that the input is + # always padded to the maximum length. If we support other + # encoder-decoder models, this will need to be updated if we + # want to only allocate what is needed. + num_encoder_tokens = \ + self.scheduler_config.max_num_encoder_input_tokens + else: + num_encoder_tokens = 0 + + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens + num_external_computed_tokens, + num_new_local_computed_tokens, + new_computed_blocks, + num_lookahead_tokens=effective_lookahead_tokens, + delay_cache_blocks=load_kv_async, + num_encoder_tokens=num_encoder_tokens, + ) + + if new_blocks is None: + # The request cannot be scheduled. + break + + # KVTransfer: the connector uses this info to determine + # if a load is needed. Note that + # This information is used to determine if a load is + # needed for this request. + if self.connector is not None: + self.connector.update_state_after_alloc( + request, + new_computed_blocks + new_blocks, + num_external_computed_tokens, + ) + + # Request was already popped from self.waiting + # unless it was re-added above due to new_blocks being None. + request = self.waiting.pop_request() + if load_kv_async: + # If loading async, allocate memory and put request + # into the WAITING_FOR_REMOTE_KV state. + skipped_waiting_requests.prepend_request(request) + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + continue + + req_index += 1 + self.running.append(request) + if self.log_stats: + request.record_event(EngineCoreEventType.SCHEDULED, + scheduled_timestamp) + if request.status == RequestStatus.WAITING: + scheduled_new_reqs.append(request) + elif request.status == RequestStatus.PREEMPTED: + scheduled_resumed_reqs.append(request) + else: + raise RuntimeError( + f"Invalid request status: {request.status}") + + if self.lora_config and request.lora_request: + scheduled_loras.add(request.lora_request.lora_int_id) + req_to_new_blocks[request.request_id] = ( + self.kv_cache_manager.get_blocks(request.request_id)) + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + # Count the number of prefix cached tokens. + if request.num_cached_tokens < 0: + request.num_cached_tokens = num_computed_tokens + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. 
+ for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_compute_budget = new_encoder_compute_budget + + # Put back any skipped requests at the head of the waiting queue + if skipped_waiting_requests: + self.waiting.prepend_requests(skipped_waiting_requests) + + # Check if the scheduling constraints are satisfied. + total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) + assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens + assert token_budget >= 0 + assert len(self.running) <= self.max_num_running_reqs + # Since some requests in the RUNNING queue may not be scheduled in + # this step, the total number of scheduled requests can be smaller than + # len(self.running). + assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + + len(scheduled_running_reqs) <= len(self.running)) + + # Get the longest common prefix among all requests in the running queue. + # This can be potentially used for cascade attention. + num_common_prefix_blocks = [0] * len( + self.kv_cache_config.kv_cache_groups) + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request, len(self.running))) + + # Construct the scheduler output. + new_reqs_data = [ + NewRequestData.from_request( + req, req_to_new_blocks[req.request_id].get_block_ids()) + for req in scheduled_new_reqs + ] + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, + scheduled_resumed_reqs, + num_scheduled_tokens, + scheduled_spec_decode_tokens, + req_to_new_blocks, + ) + scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs + + scheduled_resumed_reqs) + structured_output_request_ids, grammar_bitmask = ( + self.get_grammar_bitmask(scheduled_requests, + scheduled_spec_decode_tokens)) + scheduler_output = RecomputeSchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=cached_reqs_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between + # the previous and the current steps. + finished_req_ids=self.finished_req_ids, + free_encoder_mm_hashes=self.encoder_cache_manager. + get_freed_mm_hashes(), + structured_output_request_ids=structured_output_request_ids, + grammar_bitmask=grammar_bitmask, + recomputed_reqs=recomputed_reqs, + ) + + # NOTE(Kuntai): this function is designed for multiple purposes: + # 1. Plan the KV cache store + # 2. Wrap up all the KV cache load / save ops into an opaque object + # 3. 
Clear the internal states of the connector + if self.connector is not None: + meta = self.connector.build_connector_meta(scheduler_output) + scheduler_output.kv_connector_metadata = meta + + # collect KV cache events from KV cache manager + events = self.kv_cache_manager.take_events() + + # collect KV cache events from connector + if self.connector is not None: + connector_events = self.connector.take_events() + if connector_events: + if events is None: + events = list(connector_events) + else: + events.extend(connector_events) + + # publish collected KV cache events + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) + + self._update_after_schedule(scheduler_output) + return scheduler_output + + def _update_after_schedule( + self, + scheduler_output: RecomputeSchedulerOutput, + ) -> None: + # Advance the number of computed tokens for the request AFTER + # the request is scheduled. + # 1. The scheduler_output of the current step has to include the + # original number of scheduled tokens to determine input IDs. + # 2. Advance the number of computed tokens here allowing us to + # schedule the prefill request again immediately in the next + # scheduling step. + # 3. If some tokens (e.g. spec tokens) are rejected later, the number of + # computed tokens will be adjusted in update_from_output. + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + for req_id, num_scheduled_token in num_scheduled_tokens.items(): + request = self.requests[req_id] + request.num_computed_tokens += num_scheduled_token + + # NOTE: _free_encoder_inputs relies on num_computed_tokens, which + # may be updated again in _update_from_output for speculative + # decoding. However, it is safe to call the method here because + # encoder inputs are always part of the prompt, not the output, + # and thus are unaffected by speculative decoding. + if request.has_encoder_inputs: + self._free_encoder_inputs(request) + + # Clear the finished request IDs. + # NOTE: We shouldn't do self.finished_req_ids.clear() here because + # it will also affect the scheduler output. + self.finished_req_ids = set() + + def _make_cached_request_data( + self, + running_reqs: list[Request], + resumed_reqs: list[Request], + num_scheduled_tokens: dict[str, int], + spec_decode_tokens: dict[str, list[int]], + req_to_new_blocks: dict[str, KVCacheBlocks], + ) -> CachedRequestData: + req_ids: list[str] = [] + new_token_ids: list[list[int]] = [] + new_block_ids: list[Optional[tuple[list[int], ...]]] = [] + num_computed_tokens: list[int] = [] + + use_connector = self.connector is not None + for req in itertools.chain(running_reqs, resumed_reqs): + req_id = req.request_id + req_ids.append(req_id) + num_tokens = (num_scheduled_tokens[req_id] - + len(spec_decode_tokens.get(req_id, ()))) + if self.use_pp: + # When using PP, the scheduler sends the sampled tokens back, + # because there's no direct communication between the first- + # stage worker and the last-stage worker. Otherwise, we don't + # need to send the sampled tokens back because the model runner + # will cache them. + token_ids = req.all_token_ids[req.num_computed_tokens:req. + num_computed_tokens + num_tokens] + new_token_ids.append(token_ids) + elif use_connector: + # When using a KVConnector, we add a placeholder to avoid index + # out of bounds errors. TODO: Remove this once the KVConnector + # is updated to handle token IDs properly. 
+ new_token_ids.append([]) + new_block_ids.append( + req_to_new_blocks[req_id].get_block_ids(allow_none=True)) + num_computed_tokens.append(req.num_computed_tokens) + # Because resumed_reqs is usually empty, it is more efficient to do + # in-place appending so that we don't need to allocate a new list. + resumed_from_preemption = [False] * len(running_reqs) + resumed_from_preemption += [True] * len(resumed_reqs) + + return CachedRequestData( + req_ids=req_ids, + resumed_from_preemption=resumed_from_preemption, + new_token_ids=new_token_ids, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, + ) + + def _try_schedule_encoder_inputs( + self, + request: Request, + num_computed_tokens: int, + num_new_tokens: int, + encoder_compute_budget: int, + ) -> tuple[list[int], int, int]: + """ + Determine which encoder inputs need to be scheduled in the current step, + and update `num_new_tokens` and encoder token budget accordingly. + + An encoder input will be scheduled if: + - Its output tokens overlap with the range of tokens being computed + in this step, i.e., + [num_computed_tokens, num_computed_tokens + num_new_tokens). + - It is not already computed and stored in the encoder cache. + - There is sufficient encoder token budget to process it. + - The encoder cache has space to store it. + + If an encoder input cannot be scheduled due to cache or budget + limitations, the method adjusts `num_new_tokens` to schedule only the + decoder tokens up to just before the unschedulable encoder input. + + Note that num_computed_tokens includes both locally cached + blocks and externally cached blocks (via KVConnector). + """ + if num_new_tokens == 0 or not request.has_encoder_inputs: + return [], num_new_tokens, encoder_compute_budget + encoder_inputs_to_schedule: list[int] = [] + mm_features = request.mm_features + assert mm_features is not None + assert len(mm_features) > 0 + + # NOTE: since scheduler operates on the request level (possibly with + # multiple encoder inputs per request), we need to create temporary + # trackers for accounting at the encoder input level. + mm_hashes_to_schedule = set() + num_tokens_to_schedule = 0 + for i, mm_feature in enumerate(mm_features): + start_pos = mm_feature.mm_position.offset + num_encoder_tokens = mm_feature.mm_position.length + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, num_computed_tokens + num_new_tokens) and + # [start_pos, start_pos + num_encoder_tokens) + if start_pos >= num_computed_tokens + num_new_tokens: + # The encoder input is not needed in this step. + break + + if self.is_encoder_decoder and num_computed_tokens > 0: + assert start_pos == 0, ( + "Encoder input should be processed at the beginning of " + "the sequence when encoder-decoder models are used.") + # Encoder input has already been computed + # The calculation here is a bit different. We don't turn encoder + # output into tokens that get processed by the decoder and + # reflected in num_computed_tokens. Instead, start_pos reflects + # the position where we need to ensure we calculate encoder + # inputs. This should always be 0 to ensure we calculate encoder + # inputs before running the decoder. Once we've calculated some + # decoder tokens (num_computed_tokens > 0), then we know we + # already calculated encoder inputs and can skip here. + continue + elif start_pos + num_encoder_tokens <= num_computed_tokens: + # The encoder input is already computed and stored + # in the decoder's KV cache. 
+ continue + + if not self.is_encoder_decoder: + # We are not using the encoder cache for encoder-decoder models, + # yet. + if request.mm_features[i].identifier in mm_hashes_to_schedule: + # The same encoder input has already been scheduled in the + # current step. + continue + + if self.encoder_cache_manager.check_and_update_cache( + request, i): + # The encoder input is already computed and cached from a + # previous step. + continue + + # If no encoder input chunking is allowed, we do not want to + # partially schedule a multimodal item. If the scheduled range would + # only cover part of the mm input, roll back to before the mm item. + if (self.scheduler_config.disable_chunked_mm_input + and num_computed_tokens < start_pos + and (num_computed_tokens + num_new_tokens) + < (start_pos + num_encoder_tokens)): + num_new_tokens = start_pos - num_computed_tokens + break + + if not self.encoder_cache_manager.can_allocate( + request, i, encoder_compute_budget, + num_tokens_to_schedule): + # The encoder cache is full or the encoder budget is exhausted. + # NOTE(woosuk): We assume that the encoder input tokens should + # be processed altogether, as the encoder usually uses + # bidirectional attention. + if num_computed_tokens < start_pos: + # We only schedule the decoder tokens just before the + # encoder input. + num_new_tokens = start_pos - num_computed_tokens + else: + # Because of prefix caching, num_computed_tokens is greater + # than start_pos even though its encoder input is not + # available. In this case, we can't schedule any token for + # the request in this step. + num_new_tokens = 0 + break + + num_tokens_to_schedule += num_encoder_tokens + encoder_compute_budget -= num_encoder_tokens + mm_hashes_to_schedule.add(request.mm_features[i].identifier) + encoder_inputs_to_schedule.append(i) + + return ( + encoder_inputs_to_schedule, + num_new_tokens, + encoder_compute_budget, + ) + + def get_grammar_bitmask( + self, + requests: list[Request], + scheduled_spec_decode_tokens: dict[str, list[int]], + ): + # NOTE: structured_output_request_ids maps + # a request's (request that uses structured output) + # request_id to its index in the batch. + # This will help us determine to slice the grammar bitmask + # and only applies valid mask for requests that + # uses structured decoding. + structured_output_request_ids: dict[str, int] = {} + for i, req in enumerate(requests): + if req.use_structured_output: + # PERF: in case of chunked prefill, + # request might not include any new tokens. + # Therefore, we might introduce some additional + # cycle to fill in the bitmask, which could be a big no-op. 
+ structured_output_request_ids[req.request_id] = i + + if not structured_output_request_ids: + bitmask = None + else: + bitmask = self.structured_output_manager.grammar_bitmask( + self.requests, + structured_output_request_ids, + scheduled_spec_decode_tokens, + ) + return structured_output_request_ids, bitmask + + def update_from_output( + self, + scheduler_output: RecomputeSchedulerOutput, + model_runner_output: ModelRunnerOutput, + ) -> dict[int, EngineCoreOutputs]: + sampled_token_ids = model_runner_output.sampled_token_ids + logprobs = model_runner_output.logprobs + prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict + num_scheduled_tokens = scheduler_output.num_scheduled_tokens + pooler_outputs = model_runner_output.pooler_output + num_nans_in_logits = model_runner_output.num_nans_in_logits + kv_connector_output = model_runner_output.kv_connector_output + + outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) + spec_decoding_stats: Optional[SpecDecodingStats] = None + kv_connector_stats = (kv_connector_output.kv_connector_stats + if kv_connector_output else None) + # return recomputed requests as EngineCoreOutput + for req_info in scheduler_output.recomputed_reqs: + outputs[req_info.client_index].append( + EngineCoreOutput( + request_id=req_info.request_id, + finish_reason=FinishReason.STOP, + new_token_ids=[req_info.output_token_ids[-1]], + stop_reason="recomputed", + )) + # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, + # the below loop can be a performance bottleneck. We should do our best + # to avoid expensive operations inside the loop. + stopped_running_reqs: set[Request] = set() + stopped_preempted_reqs: set[Request] = set() + for req_id, num_tokens_scheduled in num_scheduled_tokens.items(): + assert num_tokens_scheduled > 0 + request = self.requests.get(req_id) + if request is None: + # The request is already finished. This can happen if the + # request is aborted while the model is executing it (e.g., + # in pipeline parallelism). + continue + + req_index = model_runner_output.req_id_to_index[req_id] + generated_token_ids = sampled_token_ids[ + req_index] if sampled_token_ids else [] + + scheduled_spec_token_ids = ( + scheduler_output.scheduled_spec_decode_tokens.get(req_id)) + if scheduled_spec_token_ids: + num_draft_tokens = len(scheduled_spec_token_ids) + num_accepted = len(generated_token_ids) - 1 + num_rejected = num_draft_tokens - num_accepted + # num_computed_tokens represents the number of tokens + # processed in the current step, considering scheduled + # tokens and rejections. If some tokens are rejected, + # num_computed_tokens is decreased by the number of rejected + # tokens. + request.num_computed_tokens -= num_rejected + spec_decoding_stats = self.make_spec_decoding_stats( + spec_decoding_stats, + num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted) + + stopped = False + new_logprobs = None + new_token_ids = generated_token_ids + kv_transfer_params = None + status_before_stop = request.status + + # Check for stop and update request status. + if new_token_ids: + new_token_ids, stopped = self._update_request_with_output( + request, new_token_ids) + + # Stop checking for pooler models. 
+ pooler_output = None + if pooler_outputs: + pooler_output = pooler_outputs[req_index] + stopped = check_stop(request, self.max_model_len, + pooler_output) + + if stopped: + kv_transfer_params = self._free_request(request) + if status_before_stop == RequestStatus.RUNNING: + stopped_running_reqs.add(request) + else: + stopped_preempted_reqs.add(request) + + # Extract sample logprobs if needed. + if request.sampling_params is not None \ + and request.sampling_params.logprobs is not None and logprobs: + # NOTE: once we support N tokens per step (spec decode), + # the outer lists can be of length > 1. + new_logprobs = logprobs.slice(req_index, req_index + 1) + + if new_token_ids and self.structured_output_manager.should_advance( + request): + # NOTE: structured_output_request + # should not be None if use_structured_output, we have + # checked above, so safe to ignore type warning + request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] + req_id, new_token_ids) + + if num_nans_in_logits is not None and req_id in num_nans_in_logits: + request.num_nans_in_logits = num_nans_in_logits[req_id] + + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + if new_token_ids or pooler_output is not None \ + or kv_transfer_params: + + # Add EngineCoreOutput for this Request. + outputs[request.client_index].append( + EngineCoreOutput( + request_id=req_id, + new_token_ids=new_token_ids, + finish_reason=request.get_finished_reason(), + new_logprobs=new_logprobs, + new_prompt_logprobs_tensors=prompt_logprobs_tensors, + pooling_output=pooler_output, + stop_reason=request.stop_reason, + events=request.take_events(), + kv_transfer_params=kv_transfer_params, + trace_headers=request.trace_headers, + num_cached_tokens=request.num_cached_tokens, + )) + else: + # Invariant: EngineCore returns no partial prefill outputs. + assert not prompt_logprobs_tensors + + # Remove the stopped requests from the running and waiting queues. + if stopped_running_reqs: + self.running = remove_all(self.running, stopped_running_reqs) + if stopped_preempted_reqs: + # This is a rare case and unlikely to impact performance. + self.waiting.remove_requests(stopped_preempted_reqs) + + # KV Connector: update state for finished KV Transfers. + if model_runner_output.kv_connector_output: + self._update_from_kv_xfer_finished( + model_runner_output.kv_connector_output) + + # Create EngineCoreOutputs for all clients that have requests with + # outputs in this step. + engine_core_outputs = { + client_index: EngineCoreOutputs(outputs=outs) + for client_index, outs in outputs.items() + } + + finished_req_ids = self.finished_req_ids_dict + if finished_req_ids: + # Include ids of requests that finished since last outputs + # were sent. + for client_index, finished_set in finished_req_ids.items(): + # Set finished request set in EngineCoreOutputs for this client. + if (eco := engine_core_outputs.get(client_index)) is not None: + eco.finished_requests = finished_set + else: + engine_core_outputs[client_index] = EngineCoreOutputs( + finished_requests=finished_set) + finished_req_ids.clear() + + if (stats := self.make_stats(spec_decoding_stats, + kv_connector_stats)) is not None: + # Return stats to only one of the front-ends. + if (eco := next(iter(engine_core_outputs.values()), None)) is None: + # We must return the stats even if there are no request + # outputs this step. 
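+ # An empty EngineCoreOutputs for client 0 is created here solely to
+ # carry the scheduler stats back to a front-end.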
+ engine_core_outputs[0] = eco = EngineCoreOutputs() + eco.scheduler_stats = stats + + return engine_core_outputs + + def _update_request_with_output( + self, + request: Request, + new_token_ids: list[int], + ) -> tuple[list[int], bool]: + # Append generated tokens and check for stop. Note that if + # a request is still being prefilled, we expect the model runner + # to return empty token ids for the request. + stopped = False + for num_new, output_token_id in enumerate(new_token_ids, 1): + request.append_output_token_ids(output_token_id) + + # Check for stop and update request state. + # This must be called before we make the EngineCoreOutput. + stopped = check_stop(request, self.max_model_len) + if stopped: + del new_token_ids[num_new:] # Trim new tokens if needed. + break + return new_token_ids, stopped + + def _free_encoder_inputs(self, request: Request) -> None: + cached_encoder_input_ids = ( + self.encoder_cache_manager.get_cached_input_ids(request)) + # OPTIMIZATION: Avoid list(set) if the set is empty. + if not cached_encoder_input_ids: + return + + # Here, we use list(set) to avoid modifying the set while iterating + # over it. + for input_id in list(cached_encoder_input_ids): + mm_feature = request.mm_features[input_id] + start_pos = mm_feature.mm_position.offset + num_tokens = mm_feature.mm_position.length + if self.is_encoder_decoder and request.num_computed_tokens > 0: + # With Whisper, as soon as we've generated a single token, + # we know we're done with the encoder input. Cross Attention + # KVs have been calculated and cached already. + self.encoder_cache_manager.free_encoder_input( + request, input_id) + elif start_pos + num_tokens <= request.num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. + self.encoder_cache_manager.free_encoder_input( + request, input_id) + + def update_draft_token_ids( + self, + draft_token_ids: DraftTokenIds, + ) -> None: + for req_id, spec_token_ids in zip( + draft_token_ids.req_ids, + draft_token_ids.draft_token_ids, + ): + request = self.requests.get(req_id) + if request is None or request.is_finished(): + # The request may have been finished. Skip. + continue + + # Add newly generated spec token ids to the request. + if not spec_token_ids: + # NOTE(woosuk): request.spec_token_ids should be updated. + request.spec_token_ids.clear() + elif self.structured_output_manager.should_advance(request): + metadata = request.structured_output_request + request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] + spec_token_ids) + else: + request.spec_token_ids = spec_token_ids + + def get_request_counts(self) -> tuple[int, int]: + """Returns (num_running_reqs, num_waiting_reqs).""" + return len(self.running), len(self.waiting) + + def add_request(self, request: Request) -> None: + self.waiting.add_request(request) + self.requests[request.request_id] = request + if self.log_stats: + request.record_event(EngineCoreEventType.QUEUED) + + def finish_requests( + self, + request_ids: Union[str, Iterable[str]], + finished_status: RequestStatus, + ) -> None: + """Handles the finish signal from outside the scheduler. + + For example, the API server can abort a request when the client + disconnects. 
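+
+ Unknown request IDs are silently skipped, so e.g. calling
+ finish_requests("req-0", RequestStatus.FINISHED_ABORTED) after
+ "req-0" has already been freed is a no-op.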
+ """ + assert RequestStatus.is_finished(finished_status) + if isinstance(request_ids, str): + request_ids = (request_ids, ) + else: + request_ids = set(request_ids) + + running_requests_to_remove = set() + waiting_requests_to_remove = [] + valid_requests = [] + + # First pass: collect requests to remove from queues + for req_id in request_ids: + request = self.requests.get(req_id) + if request is None: + # Invalid request ID. + continue + + valid_requests.append(request) + if request.status == RequestStatus.RUNNING: + running_requests_to_remove.add(request) + else: + waiting_requests_to_remove.append(request) + + # Remove all requests from queues at once for better efficiency + if running_requests_to_remove: + self.running = remove_all(self.running, running_requests_to_remove) + if waiting_requests_to_remove: + self.waiting.remove_requests(waiting_requests_to_remove) + + # Second pass: set status and free requests + for request in valid_requests: + request.status = finished_status + self._free_request(request) + + def _free_request(self, request: Request) -> Optional[dict[str, Any]]: + assert request.is_finished() + + delay_free_blocks, kv_xfer_params = self._connector_finished(request) + self.encoder_cache_manager.free(request) + request_id = request.request_id + self.finished_req_ids.add(request_id) + if self.finished_req_ids_dict is not None: + self.finished_req_ids_dict[request.client_index].add(request_id) + + if not delay_free_blocks: + self._free_blocks(request) + + return kv_xfer_params + + def _free_blocks(self, request: Request): + assert request.is_finished() + self.kv_cache_manager.free(request) + del self.requests[request.request_id] + + def get_num_unfinished_requests(self) -> int: + return len(self.waiting) + len(self.running) + + def has_finished_requests(self) -> bool: + return len(self.finished_req_ids) > 0 + + def reset_prefix_cache(self) -> bool: + return self.kv_cache_manager.reset_prefix_cache() + + def make_stats( + self, + spec_decoding_stats: Optional[SpecDecodingStats] = None, + kv_connector_stats: Optional[KVConnectorStats] = None, + ) -> Optional[SchedulerStats]: + if not self.log_stats: + return None + prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() + assert prefix_cache_stats is not None + return SchedulerStats(num_running_reqs=len(self.running), + num_waiting_reqs=len(self.waiting), + kv_cache_usage=self.kv_cache_manager.usage, + prefix_cache_stats=prefix_cache_stats, + spec_decoding_stats=spec_decoding_stats, + num_corrupted_reqs=sum(req.is_output_corrupted + for req in self.running), + kv_connector_stats=kv_connector_stats.data + if kv_connector_stats else None) + + def make_spec_decoding_stats( + self, + spec_decoding_stats: Optional[SpecDecodingStats], + num_draft_tokens: int, + num_accepted_tokens: int, + ) -> Optional[SpecDecodingStats]: + if not self.log_stats: + return None + if spec_decoding_stats is None: + spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) + spec_decoding_stats.observe_draft( + num_draft_tokens=num_draft_tokens, + num_accepted_tokens=num_accepted_tokens) + return spec_decoding_stats + + def shutdown(self) -> None: + if self.kv_event_publisher: + self.kv_event_publisher.shutdown() + if self.connector is not None: + self.connector.shutdown() + + ######################################################################## + # KV Connector Related Methods + ######################################################################## + + def get_kv_connector(self) -> Optional[KVConnectorBase_V1]: + return 
self.connector
+
+ def _connector_finished(
+ self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]:
+ """
+ Invoke the KV connector request_finished() method if applicable.
+
+ Returns optional kv transfer parameters to be included with the
+ request outputs.
+ """
+ if self.connector is None:
+ return False, None
+
+ (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
+ return self.connector.request_finished(request, block_ids)
+
+ def _update_waiting_for_remote_kv(self, request: Request) -> bool:
+ """
+ KV Connector: check if the request_id is finished_recving.
+
+ The finished_recving_kv_req_ids list is populated in the previous
+ step's update_from_output() based on the worker-side connector.
+
+ When the kv transfer is ready, we cache the blocks
+ and the request state will be moved back to WAITING from
+ WAITING_FOR_REMOTE_KV.
+ """
+ assert self.connector is not None
+ if request.request_id not in self.finished_recving_kv_req_ids:
+ return False
+
+ # Now that the blocks are ready, actually cache them.
+ (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
+ num_computed_tokens = len(block_ids) * self.block_size
+ # Handle the case where the request has fewer tokens than one block.
+ num_computed_tokens = min(num_computed_tokens, request.num_tokens)
+ if num_computed_tokens == request.num_tokens:
+ num_computed_tokens -= 1
+ # This will cache the blocks iff caching is enabled.
+ self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
+
+ # Update the request state for scheduling.
+ request.num_computed_tokens = num_computed_tokens
+
+ # Return that we are ready.
+ self.finished_recving_kv_req_ids.remove(request.request_id)
+ return True
+
+ def _update_from_kv_xfer_finished(self,
+ kv_connector_output: KVConnectorOutput):
+ """
+ KV Connector: update the scheduler state based on the output.
+
+ The worker-side connectors add finished_recving and
+ finished_sending reqs to the output.
+ * if finished_sending: free the blocks.
+ * if finished_recving: add to state so we can
+ schedule the request during the next step.
+ """
+
+ if self.connector is not None:
+ self.connector.update_connector_output(kv_connector_output)
+
+ # KV Connector: update recv and send status from last step.
+ for req_id in (kv_connector_output.finished_recving or ()):
+ logger.debug("Finished recving KV transfer for request %s", req_id)
+ self.finished_recving_kv_req_ids.add(req_id)
+ for req_id in (kv_connector_output.finished_sending or ()):
+ logger.debug("Finished sending KV transfer for request %s", req_id)
+ if req_id not in self.requests:
+ logger.warning(
+ "Got finished sending KV transfer for request %s, "
+ "but the request is already freed.", req_id)
+ else:
+ self._free_blocks(self.requests[req_id])
+
+
+@dataclass
+class RecomputeReqInfo:
+ request_id: str
+ output_token_ids: ConstantList
+ client_index: int = 0
+
+
+@dataclass
+class RecomputeSchedulerOutput:
+
+ # list of the requests that are scheduled for the first time.
+ # We cache the request's data in each worker process, so that we don't
+ # need to re-send it every scheduling step.
+ scheduled_new_reqs: list[NewRequestData]
+ # list of the requests that have been scheduled before.
+ # Since the request's data is already cached in the worker processes,
+ # we only send the diff to minimize the communication cost.
+ scheduled_cached_reqs: CachedRequestData
+
+ # req_id -> num_scheduled_tokens
+ # Number of tokens scheduled for each request.
+ num_scheduled_tokens: dict[str, int]
+ # Total number of tokens scheduled for all requests.
+ # Equal to sum(num_scheduled_tokens.values())
+ total_num_scheduled_tokens: int
+ # req_id -> spec_token_ids
+ # If a request does not have any spec decode tokens, it will not be
+ # included in the dictionary.
+ scheduled_spec_decode_tokens: dict[str, list[int]]
+ # req_id -> encoder input indices that need processing.
+ # E.g., if a request has [0, 1], it could mean the vision encoder needs
+ # to process that request's 0-th and 1-th images in the current step.
+ scheduled_encoder_inputs: dict[str, list[int]]
+ # Number of common prefix blocks for all requests in each KV cache group.
+ # This can be used for cascade attention.
+ num_common_prefix_blocks: list[int]
+
+ # Request IDs that finished between the previous and the current
+ # steps. This is used to notify the workers about the finished requests
+ # so that they can free the cached states for those requests.
+ finished_req_ids: set[str]
+ # list of mm_hash strings associated with the encoder outputs to be
+ # freed from the encoder cache.
+ free_encoder_mm_hashes: list[str]
+
+ # Dict of request ids to their index within the batch
+ # for filling the next token bitmask
+ structured_output_request_ids: dict[str, int]
+ # the bitmask for the whole batch
+ grammar_bitmask: Optional[npt.NDArray[np.int32]]
+
+ # Requests that need their KV cache recomputed.
+ recomputed_reqs: list[RecomputeReqInfo]
+
+ # KV Cache Connector metadata.
+ kv_connector_metadata: Optional[KVConnectorMetadata] = None
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index ac8f01151..a67f05425 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -300,6 +300,12 @@ class NPUPlatform(Platform):
 vllm_config.scheduler_config, ascend_config.ascend_scheduler_config)
 vllm_config.scheduler_config = ascend_scheduler_config
+ elif ascend_config.recompute_scheduler_enable:
+ from vllm_ascend.core.recompute_schedule_config import \
+ RecomputeSchedulerConfig
+ recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(
+ vllm_config.scheduler_config)
+ vllm_config.scheduler_config = recompute_scheduler_config
 
 @classmethod
 def get_attn_backend_cls(