Compare commits

...

1 commit

commit 90eb28ca21
[V1][Scheduler] Use dict for running queue
This is just a random idea and still needs to be benchmarked.

Potential advantages for large batch sizes:
- No need to copy the entire list every iteration
- O(1) removal of aborted requests

Signed-off-by: Nick Hill <nhill@redhat.com>
Date: 2025-03-13 16:11:07 -04:00
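
The change relies on CPython dicts preserving insertion order (a language guarantee since Python 3.7, with `reversed()` support for dicts since 3.8), which lets a dict with `None` values act as an insertion-ordered set. Below is a minimal sketch of the idiom the patch leans on, using a hypothetical `Req` stand-in rather than vLLM's actual Request class:

```python
# Sketch: a dict with None values as an insertion-ordered set.
# `Req` is a hypothetical stand-in for vLLM's Request class.
class Req:
    def __init__(self, request_id: str) -> None:
        self.request_id = request_id

running: dict[Req, None] = {}

a, b, c = Req("a"), Req("b"), Req("c")
for req in (a, b, c):
    running[req] = None          # append: plays the role of list.append()

first = next(iter(running))      # head of the queue, like running[0]
last = next(reversed(running))   # tail, like running[-1]; O(1)

del running[b]                   # O(1) removal, vs O(n) list.remove()
running.popitem()                # pops the most recently inserted (c)

assert [r.request_id for r in running] == ["a"]
```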


@@ -66,7 +66,7 @@ class Scheduler:
         self.requests: dict[str, Request] = {}
         # Priority queues for requests.
         self.waiting: deque[Request] = deque()
-        self.running: list[Request] = []
+        self.running: dict[Request, None] = {}
         # The requests that have been scheduled and are being executed
         # by the executor.
         self.scheduled_req_ids: set[str] = set()
@@ -140,12 +140,12 @@ class Scheduler:
         scheduled_timestamp = time.monotonic()
 
         # First, schedule the RUNNING requests.
-        req_index = 0
-        while req_index < len(self.running) and token_budget > 0:
-            request = self.running[req_index]
+        running_count = len(self.running)
+        for req_index, request in enumerate(self.running):
+            if token_budget <= 0 or req_index == running_count:
+                break
 
             if request.request_id in self.scheduled_req_ids:
                 # This request has already been scheduled.
-                req_index += 1
                 continue
 
             num_new_tokens = (request.num_tokens_with_spec -
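
One wrinkle the new loop has to respect: CPython raises a RuntimeError if a dict changes size while it is being iterated, so preempted requests cannot be removed from `self.running` inside the `for` loop. That is why the code tracks a separate `running_count` boundary and defers the actual removals to a `popitem()` pass after the loop (see the `@@ -230,6 +229,10 @@` hunk below). A quick demonstration of the constraint:

```python
# Mutating a dict during iteration raises RuntimeError in CPython.
d = {"a": None, "b": None, "c": None}
try:
    for key in d:
        d.popitem()  # removing entries while iterating is not allowed
except RuntimeError as e:
    print(e)  # dictionary changed size during iteration
```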
@@ -165,7 +165,6 @@ class Scheduler:
                 # NOTE(woosuk): Here, by doing `continue` instead of `break`,
                 # we do not strictly follow the FCFS scheduling policy and
                 # allow the lower-priority requests to be scheduled.
-                req_index += 1
                 continue
 
             while True:
@@ -174,7 +173,8 @@ class Scheduler:
                 if new_blocks is None:
                     # The request cannot be scheduled.
                     # Preempt the lowest-priority request.
-                    preempted_req = self.running.pop()
+                    preempted_req = next(reversed(self.running))
+                    running_count -= 1
                     self.kv_cache_manager.free(preempted_req)
                     preempted_req.status = RequestStatus.PREEMPTED
                     preempted_req.num_computed_tokens = 0
@@ -182,7 +182,7 @@ class Scheduler:
                     self.waiting.appendleft(preempted_req)
                     preempted_reqs.append(preempted_req)
 
-                    if preempted_req == request:
+                    if req_index == running_count:
                         # No more request to preempt.
                         can_schedule = False
                         break
@@ -208,7 +208,6 @@ class Scheduler:
             ]
             num_scheduled_tokens[request.request_id] = num_new_tokens
             token_budget -= num_new_tokens
-            req_index += 1
 
             # Speculative decode related.
             if request.spec_token_ids:
@@ -230,6 +229,10 @@ class Scheduler:
                     self.encoder_cache_manager.allocate(request, i)
                 encoder_budget = new_encoder_budget
 
+        # Remove preempted requests from the running queue.
+        while len(self.running) > running_count:
+            self.running.popitem()
+
         # Record the LoRAs in scheduled_running_reqs
         requested_loras: set[int] = set()
         if self.lora_config:
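
These four added lines are that deferred cleanup: `dict.popitem()` removes the most recently inserted entry, and because preemption always takes requests from the tail of `self.running`, trimming the dict back down to `running_count` drops exactly the preempted requests. A small sketch with strings standing in for requests:

```python
# Deferred tail-trim: popitem() removes entries in LIFO order, so
# shrinking the dict back to running_count drops only the tail
# entries that were logically preempted during the scheduling loop.
running = dict.fromkeys(["r0", "r1", "r2", "r3"])
running_count = len(running)

running_count -= 2  # two tail requests were preempted during the loop

while len(running) > running_count:
    running.popitem()

assert list(running) == ["r0", "r1"]
```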
@@ -255,9 +258,8 @@ class Scheduler:
                 if structured_output_req and structured_output_req.grammar:
                     request.status = RequestStatus.WAITING
                 else:
-                    waiting_structured_output_req = self.waiting.popleft()
-                    waiting_for_fsm.appendleft(
-                        waiting_structured_output_req)
+                    self.waiting.popleft()
+                    waiting_for_fsm.appendleft(request)
                     continue
 
             # Check that adding the request still respects the max_loras
@@ -316,9 +318,8 @@ class Scheduler:
             self.waiting.popleft()
             if request.use_structured_output:
                 structured_output_request_ids[
-                    request.request_id] = req_index
-            req_index += 1
-            self.running.append(request)
+                    request.request_id] = running_count
+            self.running[request] = None
             self.scheduled_req_ids.add(request.request_id)
             self.request_scheduled(request, scheduled_timestamp)
             if request.status == RequestStatus.WAITING:
@@ -367,7 +368,7 @@ class Scheduler:
         # This can be potentially used for cascade attention.
         num_common_prefix_blocks = 0
         if self.running:
-            any_request = self.running[0]
+            any_request = next(iter(self.running))
             num_common_prefix_blocks = (
                 self.kv_cache_manager.get_num_common_prefix_blocks(
                     any_request, len(self.running)))
@@ -531,7 +532,7 @@ class Scheduler:
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
 
-        new_running: list[Request] = []
+        stopped_requests: list[Request] = []
         outputs: list[EngineCoreOutput] = []
 
         # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
@@ -542,7 +543,6 @@ class Scheduler:
             num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
             if num_tokens_scheduled == 0:
                 # The request was not scheduled in this step.
-                new_running.append(request)
                 continue
 
             req_index = model_runner_output.req_id_to_index[req_id]
@@ -601,6 +601,7 @@ class Scheduler:
                 stopped = self._check_stop(request)
                 if stopped:
                     self._free_request(request)
+                    stopped_requests.append(request)
                     break
 
             # Extract sample logprobs if needed.
@@ -635,10 +636,10 @@ class Scheduler:
                     events=request.take_events()))
 
             self.scheduled_req_ids.remove(request.request_id)
-            if not stopped:
-                new_running.append(request)
 
-        self.running = new_running
+        for stopped_request in stopped_requests:
+            del self.running[stopped_request]
+
         return EngineCoreOutputs(
             outputs=outputs,
             scheduler_stats=self.make_stats(),
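
This collect-then-delete shape is what the commit message's first bullet ("no need to copy the entire list every iteration") refers to: instead of rebuilding `new_running` on every engine step, the loop records the typically few stopped requests and deletes them afterwards. A condensed before/after sketch, with strings standing in for requests and illustrative names:

```python
# Strings stand in for Request objects; names are illustrative.
running_list = ["r0", "r1", "r2", "r3"]
running_dict = dict.fromkeys(running_list)
stopped_requests = ["r1"]  # found while processing model output

# Before: rebuild the whole list, copying all n entries every step.
running_list = [r for r in running_list if r not in stopped_requests]

# After: one O(1) dict delete per stopped request; nothing is copied.
for req in stopped_requests:
    del running_dict[req]

assert running_list == list(running_dict) == ["r0", "r2", "r3"]
```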
@@ -691,7 +692,7 @@ class Scheduler:
                 continue
 
             if request.status == RequestStatus.RUNNING:
-                self.running.remove(request)
+                del self.running[request]
                 self.scheduled_req_ids.discard(request.request_id)
             else:
                 self.waiting.remove(request)
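
This last hunk is the O(1) abort-removal path from the commit message's second bullet. Since the author flags benchmarking as the open question, a rough `timeit` sketch (not part of the PR; absolute numbers will vary by machine) captures the asymmetry being targeted:

```python
import timeit

N = 10_000  # roughly the "large batch" regime the commit message targets

# Remove-and-readd one element: O(n) scan for a list, O(1) for a dict.
t_list = timeit.timeit(f"xs.remove({N // 2}); xs.append({N // 2})",
                       setup=f"xs = list(range({N}))", number=1_000)
t_dict = timeit.timeit(f"del xs[{N // 2}]; xs[{N // 2}] = None",
                       setup=f"xs = dict.fromkeys(range({N}))", number=1_000)
print(f"list.remove: {t_list:.4f}s   dict del: {t_dict:.4f}s")
```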