Add common states

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-03-12 22:04:57 -07:00
parent 6cd1b1a18c
commit da07067215
2 changed files with 60 additions and 47 deletions

View File

@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.v1.core.sched.output import CachedRequestData
from vllm.v1.request import Request
class CommonSchedulerStates:
def __init__(self):
# The request IDs that are finished in between the previous and the
# current steps. This is used to notify the workers about the finished
# requests so that they can free the cached states for those requests.
# This is flushed at the end of each scheduling step.
self.finished_req_ids: set[str] = set()
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
# Request id -> CachedRequestData
self._cached_reqs_data: dict[str, CachedRequestData] = {}
def make_cached_request_data(
self,
request: Request,
num_scheduled_tokens: int,
num_scheduled_spec_tokens: int,
new_block_ids: list[int],
resumed_from_preemption: bool,
) -> CachedRequestData:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
num_computed_tokens = request.num_computed_tokens
num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
new_token_ids = request.all_token_ids[
num_computed_tokens:num_computed_tokens + num_regular_tokens]
req_data = self._cached_reqs_data.get(request.request_id)
if req_data is not None:
req_data.resumed_from_preemption = resumed_from_preemption
req_data.new_token_ids = new_token_ids
req_data.new_block_ids = new_block_ids
req_data.num_computed_tokens = num_computed_tokens
else:
req_data = CachedRequestData.from_request(request,
resumed_from_preemption,
new_token_ids,
new_block_ids)
self._cached_reqs_data[request.request_id] = req_data
return req_data
def free_request(self, request: Request) -> None:
self._cached_reqs_data.pop(request.request_id, None)
self.finished_req_ids.add(request.request_id)

View File

@ -13,9 +13,9 @@ from vllm.logger import init_logger
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.sched.common import CommonSchedulerStates
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreOutput, EngineCoreOutputs)
@ -73,16 +73,8 @@ class Scheduler(SchedulerInterface):
# by the executor.
self.scheduled_req_ids: set[str] = set()
# The request IDs that are finished in between the previous and the
# current steps. This is used to notify the workers about the finished
# requests so that they can free the cached states for those requests.
# This is flushed at the end of each scheduling step.
self.finished_req_ids: set[str] = set()
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
# Request id -> CachedRequestData
self._cached_reqs_data: dict[str, CachedRequestData] = {}
# Misc states for the scheduler.
self.states = CommonSchedulerStates()
# Encoder-related.
# Calculate encoder cache size if applicable
@ -386,7 +378,7 @@ class Scheduler(SchedulerInterface):
for req in scheduled_new_reqs
]
resumed_reqs_data = [
self._make_cached_request_data(
self.states.make_cached_request_data(
req,
num_scheduled_tokens[req.request_id],
len(scheduled_spec_decode_tokens.get(req.request_id, ())),
@ -395,7 +387,7 @@ class Scheduler(SchedulerInterface):
) for req in scheduled_resumed_reqs
]
running_reqs_data = [
self._make_cached_request_data(
self.states.make_cached_request_data(
req,
num_scheduled_tokens[req.request_id],
len(scheduled_spec_decode_tokens.get(req.request_id, ())),
@ -415,43 +407,15 @@ class Scheduler(SchedulerInterface):
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
finished_req_ids=self.states.finished_req_ids,
free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask,
)
self.finished_req_ids = set()
self.states.finished_req_ids = set()
return scheduler_output
def _make_cached_request_data(
self,
request: Request,
num_scheduled_tokens: int,
num_scheduled_spec_tokens: int,
new_block_ids: list[int],
resumed_from_preemption: bool,
) -> CachedRequestData:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
num_computed_tokens = request.num_computed_tokens
num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
new_token_ids = request.all_token_ids[
num_computed_tokens:num_computed_tokens + num_regular_tokens]
req_data = self._cached_reqs_data.get(request.request_id)
if req_data is not None:
req_data.resumed_from_preemption = resumed_from_preemption
req_data.new_token_ids = new_token_ids
req_data.new_block_ids = new_block_ids
req_data.num_computed_tokens = num_computed_tokens
else:
req_data = CachedRequestData.from_request(request,
resumed_from_preemption,
new_token_ids,
new_block_ids)
self._cached_reqs_data[request.request_id] = req_data
return req_data
def _try_schedule_encoder_inputs(
self,
request: Request,
@ -688,15 +652,14 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.free(request)
self.kv_cache_manager.free_block_hashes(request)
self.encoder_cache_manager.free(request)
self._cached_reqs_data.pop(request.request_id, None)
self.states.free_request(request)
del self.requests[request.request_id]
self.finished_req_ids.add(request.request_id)
def get_num_unfinished_requests(self) -> int:
return len(self.waiting) + len(self.running)
def has_finished_requests(self) -> bool:
return len(self.finished_req_ids) > 0
return len(self.states.finished_req_ids) > 0
def get_num_unscheduled_requests(self) -> int:
"""Number of requests that are not being processed by the executor."""