Mirror of https://github.com/vllm-project/vllm.git
Add common states
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
vllm/v1/core/sched/common.py (new file, 50 lines added)
@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: Apache-2.0
+from vllm.v1.core.sched.output import CachedRequestData
+from vllm.v1.request import Request
+
+
+class CommonSchedulerStates:
+
+    def __init__(self):
+        # The request IDs that are finished in between the previous and the
+        # current steps. This is used to notify the workers about the finished
+        # requests so that they can free the cached states for those requests.
+        # This is flushed at the end of each scheduling step.
+        self.finished_req_ids: set[str] = set()
+
+        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
+        # them at each scheduling step.
+        # Request id -> CachedRequestData
+        self._cached_reqs_data: dict[str, CachedRequestData] = {}
+
+    def make_cached_request_data(
+        self,
+        request: Request,
+        num_scheduled_tokens: int,
+        num_scheduled_spec_tokens: int,
+        new_block_ids: list[int],
+        resumed_from_preemption: bool,
+    ) -> CachedRequestData:
+        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
+        # them at each scheduling step.
+        num_computed_tokens = request.num_computed_tokens
+        num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
+        new_token_ids = request.all_token_ids[
+            num_computed_tokens:num_computed_tokens + num_regular_tokens]
+        req_data = self._cached_reqs_data.get(request.request_id)
+        if req_data is not None:
+            req_data.resumed_from_preemption = resumed_from_preemption
+            req_data.new_token_ids = new_token_ids
+            req_data.new_block_ids = new_block_ids
+            req_data.num_computed_tokens = num_computed_tokens
+        else:
+            req_data = CachedRequestData.from_request(request,
+                                                      resumed_from_preemption,
+                                                      new_token_ids,
+                                                      new_block_ids)
+            self._cached_reqs_data[request.request_id] = req_data
+        return req_data
+
+    def free_request(self, request: Request) -> None:
+        self._cached_reqs_data.pop(request.request_id, None)
+        self.finished_req_ids.add(request.request_id)
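The OPTIMIZATION comment above is the core of the new helper: make_cached_request_data mutates and returns a previously allocated CachedRequestData rather than constructing a fresh one on every scheduling step. A minimal, self-contained sketch of that reuse pattern follows; ReqData and ReqDataCache are hypothetical stand-ins for illustration, not vllm classes:

from dataclasses import dataclass


@dataclass
class ReqData:
    # Hypothetical stand-in for CachedRequestData, for illustration only.
    req_id: str
    new_token_ids: list[int]
    num_computed_tokens: int


class ReqDataCache:
    """Mutate-and-reuse cache mirroring _cached_reqs_data above."""

    def __init__(self):
        self._cached: dict[str, ReqData] = {}

    def make(self, req_id: str, new_token_ids: list[int],
             num_computed_tokens: int) -> ReqData:
        data = self._cached.get(req_id)
        if data is not None:
            # Hot path: overwrite the fields in place, no new allocation.
            data.new_token_ids = new_token_ids
            data.num_computed_tokens = num_computed_tokens
        else:
            # First time this request is scheduled: allocate and remember it.
            data = ReqData(req_id, new_token_ids, num_computed_tokens)
            self._cached[req_id] = data
        return data


cache = ReqDataCache()
first = cache.make("req-0", [1, 2], 0)
second = cache.make("req-0", [3], 2)
assert first is second  # the same object is recycled across steps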
vllm/v1/core/sched/scheduler.py
@@ -13,9 +13,9 @@ from vllm.logger import init_logger
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                 compute_encoder_budget)
 from vllm.v1.core.kv_cache_manager import KVCacheManager
+from vllm.v1.core.sched.common import CommonSchedulerStates
 from vllm.v1.core.sched.interface import SchedulerInterface
-from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
-                                       SchedulerOutput)
+from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
 from vllm.v1.core.sched.utils import check_stop
 from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
                             EngineCoreOutput, EngineCoreOutputs)
@@ -73,16 +73,8 @@ class Scheduler(SchedulerInterface):
         # by the executor.
         self.scheduled_req_ids: set[str] = set()
 
-        # The request IDs that are finished in between the previous and the
-        # current steps. This is used to notify the workers about the finished
-        # requests so that they can free the cached states for those requests.
-        # This is flushed at the end of each scheduling step.
-        self.finished_req_ids: set[str] = set()
-
-        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
-        # them at each scheduling step.
-        # Request id -> CachedRequestData
-        self._cached_reqs_data: dict[str, CachedRequestData] = {}
+        # Misc states for the scheduler.
+        self.states = CommonSchedulerStates()
 
         # Encoder-related.
         # Calculate encoder cache size if applicable
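Design note: none of the removed fields were specific to this particular Scheduler, so bundling them into CommonSchedulerStates lets other SchedulerInterface implementations share the same bookkeeping. A hedged sketch of what such reuse might look like (MyScheduler is hypothetical, not part of this commit):

from vllm.v1.core.sched.common import CommonSchedulerStates
from vllm.v1.core.sched.interface import SchedulerInterface


class MyScheduler(SchedulerInterface):
    # Hypothetical alternative scheduler: implements the same interface and
    # reuses the shared finished-request tracking and CachedRequestData cache.
    def __init__(self) -> None:
        self.states = CommonSchedulerStates()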
@@ -386,7 +378,7 @@ class Scheduler(SchedulerInterface):
             for req in scheduled_new_reqs
         ]
         resumed_reqs_data = [
-            self._make_cached_request_data(
+            self.states.make_cached_request_data(
                 req,
                 num_scheduled_tokens[req.request_id],
                 len(scheduled_spec_decode_tokens.get(req.request_id, ())),
@@ -395,7 +387,7 @@ class Scheduler(SchedulerInterface):
             ) for req in scheduled_resumed_reqs
         ]
         running_reqs_data = [
-            self._make_cached_request_data(
+            self.states.make_cached_request_data(
                 req,
                 num_scheduled_tokens[req.request_id],
                 len(scheduled_spec_decode_tokens.get(req.request_id, ())),
@@ -415,43 +407,15 @@ class Scheduler(SchedulerInterface):
             # instead of being newly scheduled in this step.
             # It contains the request IDs that are finished in between
             # the previous and the current steps.
-            finished_req_ids=self.finished_req_ids,
+            finished_req_ids=self.states.finished_req_ids,
             free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
             structured_output_request_ids=structured_output_request_ids,
             grammar_bitmask=grammar_bitmask,
         )
 
-        self.finished_req_ids = set()
+        self.states.finished_req_ids = set()
         return scheduler_output
 
-    def _make_cached_request_data(
-        self,
-        request: Request,
-        num_scheduled_tokens: int,
-        num_scheduled_spec_tokens: int,
-        new_block_ids: list[int],
-        resumed_from_preemption: bool,
-    ) -> CachedRequestData:
-        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
-        # them at each scheduling step.
-        num_computed_tokens = request.num_computed_tokens
-        num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
-        new_token_ids = request.all_token_ids[
-            num_computed_tokens:num_computed_tokens + num_regular_tokens]
-        req_data = self._cached_reqs_data.get(request.request_id)
-        if req_data is not None:
-            req_data.resumed_from_preemption = resumed_from_preemption
-            req_data.new_token_ids = new_token_ids
-            req_data.new_block_ids = new_block_ids
-            req_data.num_computed_tokens = num_computed_tokens
-        else:
-            req_data = CachedRequestData.from_request(request,
-                                                      resumed_from_preemption,
-                                                      new_token_ids,
-                                                      new_block_ids)
-            self._cached_reqs_data[request.request_id] = req_data
-        return req_data
-
     def _try_schedule_encoder_inputs(
         self,
         request: Request,
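Note that the flush rebinds the attribute (self.states.finished_req_ids = set()) rather than calling .clear(), presumably because the SchedulerOutput built above still holds a reference to the old set; clearing in place would empty it out from under the consumer. A small stand-alone illustration of the difference:

def build_output(finished: set[str]) -> dict:
    # Stand-in for building a SchedulerOutput that carries finished_req_ids.
    return {"finished_req_ids": finished}


finished_req_ids = {"req-0", "req-1"}
output = build_output(finished_req_ids)

# Rebinding, as the scheduler does: the output keeps the populated set.
finished_req_ids = set()
assert output["finished_req_ids"] == {"req-0", "req-1"}

# Clearing in place would have emptied the set inside the output as well.
output["finished_req_ids"].clear()
assert output["finished_req_ids"] == set()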
@@ -688,15 +652,14 @@ class Scheduler(SchedulerInterface):
         self.kv_cache_manager.free(request)
         self.kv_cache_manager.free_block_hashes(request)
         self.encoder_cache_manager.free(request)
-        self._cached_reqs_data.pop(request.request_id, None)
+        self.states.free_request(request)
         del self.requests[request.request_id]
-        self.finished_req_ids.add(request.request_id)
 
     def get_num_unfinished_requests(self) -> int:
         return len(self.waiting) + len(self.running)
 
     def has_finished_requests(self) -> bool:
-        return len(self.finished_req_ids) > 0
+        return len(self.states.finished_req_ids) > 0
 
     def get_num_unscheduled_requests(self) -> int:
         """Number of requests that are not being processed by the executor."""
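Putting it together, the per-step lifecycle of the shared state is: a finishing request goes through free_request, which drops its CachedRequestData and records its ID; the next schedule() call reports finished_req_ids in the SchedulerOutput and then flushes the set. A short sketch (assumes a vllm checkout; FinishedReq is a hypothetical duck-typed stand-in exposing only the request_id attribute that free_request touches):

from vllm.v1.core.sched.common import CommonSchedulerStates


class FinishedReq:
    # Hypothetical stand-in for vllm.v1.request.Request.
    def __init__(self, request_id: str) -> None:
        self.request_id = request_id


states = CommonSchedulerStates()
states.free_request(FinishedReq("req-0"))    # request finished between steps
assert states.finished_req_ids == {"req-0"}  # next step reports this set ...
states.finished_req_ids = set()              # ... then rebinds a fresh one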