# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from typing import Optional
from unittest.mock import Mock
import pytest
import torch
from vllm.config import (
CacheConfig,
KVTransferConfig,
ModelConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
)
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalKwargsItem,
PlaceholderRange,
)
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
)
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.request import StructuredOutputRequest
from .utils import EOS_TOKEN_ID, create_requests, create_scheduler
pytestmark = pytest.mark.cpu_test
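# Adding requests should register them in scheduler.requests and grow the
# waiting queue one request at a time.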
def test_add_requests():
scheduler = create_scheduler()
requests = create_requests(num_requests=10)
for i, request in enumerate(requests):
scheduler.add_request(request)
assert request.request_id in scheduler.requests
assert len(scheduler.waiting) == i + 1
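# Finishing (aborting) a request should remove it from scheduler.requests and
# from the waiting queue.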
def test_finish_request():
scheduler = create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id, RequestStatus.FINISHED_ABORTED)
assert request.request_id not in scheduler.requests
assert len(scheduler.waiting) == 9 - i
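# get_num_unfinished_requests() should decrease by one each time a request
# finishes.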
def test_get_num_unfinished_requests():
scheduler = create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id, RequestStatus.FINISHED_STOPPED)
assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1
@pytest.mark.parametrize(
"enable_prefix_caching, prompt_logprobs",
[
(None, None),
(True, 5),
],
)
def test_schedule(
enable_prefix_caching: Optional[bool], prompt_logprobs: Optional[int]
):
"""Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
"""
scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
requests = create_requests(num_requests=10, prompt_logprobs=prompt_logprobs)
for request in requests:
scheduler.add_request(request)
# Test initial scheduling
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
# Verify requests moved from waiting to running
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == len(requests)
for i, request in enumerate(requests):
assert scheduler.running[i] == request
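# Multimodal requests should schedule their encoder inputs alongside the
# prompt tokens (one encoder input per request here).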
def test_schedule_multimodal_requests():
scheduler = create_scheduler(model="llava-hf/llava-1.5-7b-hf")
mm_positions = [[PlaceholderRange(offset=i, length=100)] for i in range(10)]
requests = create_requests(
num_requests=10,
num_tokens=200,
mm_positions=mm_positions,
)
for request in requests:
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
for req_id, num_tokens in output.num_scheduled_tokens.items():
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
assert len(output.scheduled_encoder_inputs) == 10
for req_id, encoder_input in output.scheduled_encoder_inputs.items():
assert len(encoder_input) == 1
def test_schedule_partial_requests():
"""Test scheduling behavior with partial requests.
This test verifies that:
1. The scheduler can handle multiple partial requests in a single step when
constrained by encoder budget.
2. A request in RUNNING state may be unscheduled in subsequent steps if
there is insufficient encoder budget.
"""
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
max_num_batched_tokens=1024,
)
mm_positions = [[PlaceholderRange(offset=100, length=600)] for _ in range(3)]
requests = create_requests(
num_requests=3,
num_tokens=800,
mm_positions=mm_positions,
)
for request in requests:
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 3
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
assert scheduler.max_num_encoder_input_tokens == 1024
# The first request is scheduled fully.
assert output.num_scheduled_tokens[requests[0].request_id] == 800
# The second request is scheduled partially.
# The <img> tokens are not scheduled because of the encoder budget.
assert output.num_scheduled_tokens[requests[1].request_id] == 100
# The third request is also scheduled partially.
# The <img> tokens are not scheduled because of the encoder budget.
assert output.num_scheduled_tokens[requests[2].request_id] == 100
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
# Only the first request has a sampled token id because
# the other requests are still being prefilled.
sampled_token_ids=[[0], [], []],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_runner_output)
# Schedule the next step.
# Only the first and second requests are scheduled.
# The third request is in the RUNNING state but not scheduled in this step
# because of the encoder budget.
output = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output.scheduled_new_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 2
assert len(output.finished_req_ids) == 0
assert output.num_scheduled_tokens[requests[0].request_id] == 1
assert output.num_scheduled_tokens[requests[1].request_id] == 700
assert requests[2].request_id not in output.num_scheduled_tokens
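# With chunked multimodal input disabled, a multimodal placeholder must be
# scheduled in a single step: only the leading text tokens are scheduled
# first, and the whole 800-token placeholder goes in the next step.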
def test_no_mm_input_chunking():
# Disable multimodal input chunking.
scheduler = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
max_num_batched_tokens=1024,
disable_chunked_mm_input=True,
max_model_len=2048,
)
mm_positions = [[PlaceholderRange(offset=400, length=800)]]
requests = create_requests(
num_requests=1, num_tokens=1200, mm_positions=mm_positions
)
for request in requests:
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# We want to only see the 400 text tokens at the start scheduled
assert output.num_scheduled_tokens[requests[0].request_id] == 400
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
sampled_token_ids=[[] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_runner_output)
output = scheduler.schedule()
assert len(scheduler.running) == 1
assert len(output.scheduled_new_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 1
assert len(output.finished_req_ids) == 0
assert output.num_scheduled_tokens[requests[0].request_id] == 800
# Test that we fail if we disable chunked mm input and use too small
# of a max_num_batched_tokens for the mm input.
with pytest.raises(ValueError):
_ = create_scheduler(
model="llava-hf/llava-1.5-7b-hf",
max_num_batched_tokens=100,
disable_chunked_mm_input=True,
)
@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
"""Test scheduling behavior with concurrent partial requests.
This test verifies that multiple long prefill requests in the RUNNING
state can be scheduled together.
"""
scheduler = create_scheduler(
model="facebook/opt-125m",
max_num_batched_tokens=1024,
long_prefill_token_threshold=400,
enable_prefix_caching=enable_prefix_caching,
)
requests = create_requests(
num_requests=3,
num_tokens=800,
)
for request in requests:
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 3
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# The first request is scheduled partially - 400.
assert output.num_scheduled_tokens[requests[0].request_id] == 400
# The second request is scheduled partially - 400.
assert output.num_scheduled_tokens[requests[1].request_id] == 400
# The third request is also scheduled partially - 1024 - 400 - 400 = 224.
assert output.num_scheduled_tokens[requests[2].request_id] == 224
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
sampled_token_ids=[[] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_runner_output)
# Schedule the next step. All three requests are running.
# Process the remaining prefills of the first and second requests.
output1 = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output1.scheduled_new_reqs) == 0
assert output1.scheduled_cached_reqs.num_reqs == 3
assert len(output1.finished_req_ids) == 0
assert output1.num_scheduled_tokens[requests[0].request_id] == 400
assert output1.num_scheduled_tokens[requests[1].request_id] == 400
assert output1.num_scheduled_tokens[requests[2].request_id] == 224
# Schedule the third step. All three requests are running.
# First and second requests are in the decode stage.
# All the remaining tokens in the third request are processed.
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output1, model_runner_output)
output2 = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output2.scheduled_new_reqs) == 0
assert output2.scheduled_cached_reqs.num_reqs == 3
assert len(output2.finished_req_ids) == 0
assert output2.num_scheduled_tokens[requests[0].request_id] == 1
assert output2.num_scheduled_tokens[requests[1].request_id] == 1
assert output2.num_scheduled_tokens[requests[2].request_id] == 800 - 224 - 224
def test_stop_via_update_from_output():
"""Test stopping behavior through update_from_output"""
scheduler = create_scheduler(num_speculative_tokens=1)
# Test case 1: Stop on EOS token
requests = create_requests(num_requests=2, max_tokens=10)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={requests[0].request_id: 1, requests[1].request_id: 2},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10],
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
sampled_token_ids=[
[EOS_TOKEN_ID],
[10, 11],
], # First request hits EOS, second continues
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped, second continues
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
assert list(requests[1].output_token_ids) == [10, 11]
# Test case 2: Stop on custom stop token
scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=2, max_tokens=10, stop_token_ids=[42, 43])
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={requests[0].request_id: 3, requests[1].request_id: 2},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13],
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped on custom token
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].stop_reason == 42
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 42]
assert list(requests[1].output_token_ids) == [13, 14]
# Test case 3: Stop on max tokens
scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=2, max_tokens=2)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={requests[0].request_id: 3, requests[1].request_id: 1},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: [],
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped due to length
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 11] # Truncated to max_tokens
assert list(requests[1].output_token_ids) == [13]
# Test case 4: Ignore EOS flag
scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=1, max_tokens=10)
requests[0].sampling_params.ignore_eos = True
requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0])
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output, model_output)
# Verify request continues past EOS
assert len(scheduler.running) == 1
assert not requests[0].is_finished()
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
def test_check_stop_min_tokens():
"""Test that requests don't stop when min_tokens requirement isn't met."""
from vllm.v1.core.sched.utils import check_stop
# Test case 1: num_output_tokens < min_tokens
# Should return False (don't stop)
sampling_params = SamplingParams(
ignore_eos=False,
max_tokens=20,
min_tokens=5,
)
request = Request(
request_id="0",
prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
)
# Simulate having generated 3 output tokens (less than min_tokens=5)
request.append_output_token_ids([10, 11, EOS_TOKEN_ID]) # EOS token present
result = check_stop(request, max_model_len=100)
assert result is False, "Should not stop when num_output_tokens<min_tokens"
# Test case 2: num_output_tokens >= min_tokens
# Should follow normal stopping logic (stop on EOS)
request.append_output_token_ids(
[
10,
11,
12,
13,
14,
EOS_TOKEN_ID,
]
) # 6 tokens > min_tokens
result = check_stop(request, max_model_len=100)
assert result is True, "Should stop on EOS when min_tokens met"
assert request.status == RequestStatus.FINISHED_STOPPED
# Test case 3: min_tokens = 0, should follow normal stopping logic
sampling_params_no_min = SamplingParams(
ignore_eos=False,
max_tokens=20,
min_tokens=0,
)
request_no_min = Request(
request_id="1",
prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params_no_min,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
)
request_no_min.append_output_token_ids([10, EOS_TOKEN_ID])
result = check_stop(request_no_min, max_model_len=100)
assert result is True, "Should stop on EOS when min_tokens=0"
assert request_no_min.status == RequestStatus.FINISHED_STOPPED
# Test case 4: min_tokens > 0 with stop token (not EOS)
sampling_params_stop = SamplingParams(
ignore_eos=False,
max_tokens=20,
min_tokens=5,
stop_token_ids=[42],
)
request_stop = Request(
request_id="2",
prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params_stop,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
)
# Only 3 output tokens, less than min_tokens=5, but has stop token
request_stop.append_output_token_ids([10, 11, 42])
result = check_stop(request_stop, max_model_len=100)
assert result is False, "Should not stop when num_output_tokens<min_tokens"
# Test case 5: min_tokens met, should stop on stop token
request_stop.append_output_token_ids(
[10, 11, 12, 13, 14, 42]
) # 6 tokens >= min_tokens=5
result = check_stop(request_stop, max_model_len=100)
assert result is True, "Should stop on stop token when min_tokens met"
assert request_stop.status == RequestStatus.FINISHED_STOPPED
assert request_stop.stop_reason == 42
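# Two in-flight batches: the second request is scheduled while the first is
# still executing, as can happen when max_concurrent_batches > 1.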
@pytest.mark.parametrize(
"enable_prefix_caching, prompt_logprobs",
[
(None, None),
(True, 5),
],
)
def test_schedule_concurrent_batches(
enable_prefix_caching: Optional[bool], prompt_logprobs: Optional[int]
):
scheduler = create_scheduler(
max_num_batched_tokens=1024,
max_num_seqs=2,
enable_prefix_caching=enable_prefix_caching,
)
requests = create_requests(
num_requests=2,
num_tokens=512,
prompt_logprobs=prompt_logprobs,
)
# Schedule the first request.
scheduler.add_request(requests[0])
scheduler_output0 = scheduler.schedule()
assert len(scheduler_output0.scheduled_new_reqs) == 1
assert scheduler_output0.num_scheduled_tokens[requests[0].request_id] == 512
# The first request is still running, so only schedule the second request.
scheduler.add_request(requests[1])
scheduler_output1 = scheduler.schedule()
assert len(scheduler_output1.scheduled_new_reqs) == 1
assert scheduler_output1.num_scheduled_tokens[requests[1].request_id] == 512
# Model output of the first request.
model_runner_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output0, model_runner_output)
# Schedule the next step.
# The first request can be scheduled again while the second
# request is still running.
scheduler_output2 = scheduler.schedule()
assert scheduler_output2.num_scheduled_tokens[requests[0].request_id] == 1
# Model output of the second request.
model_runner_output = ModelRunnerOutput(
req_ids=[requests[1].request_id],
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output1, model_runner_output)
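# A request can be preempted while another batch is still executing; the
# preempted request must still receive the token sampled for it.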
def test_preempt_during_execution():
# NOTE(woosuk): The actual number of available blocks is 10 instead of 11
# because block 0 is reserved as the null block.
scheduler = create_scheduler(
max_num_batched_tokens=100,
block_size=16,
num_blocks=11,
enable_prefix_caching=False,
)
requests = create_requests(num_requests=2, num_tokens=80, block_size=16)
# Schedule the first request.
scheduler.add_request(requests[0])
scheduler_output0 = scheduler.schedule()
assert len(scheduler_output0.num_scheduled_tokens) == 1
assert len(scheduler_output0.scheduled_new_reqs[0].block_ids[0]) == 5
# Schedule the second request while the first request is still running.
# This scenario can occur in certain cases, when max_concurrent_batches > 1
# (e.g., when pipeline parallelism is used).
scheduler.add_request(requests[1])
scheduler_output1 = scheduler.schedule()
assert len(scheduler_output1.num_scheduled_tokens) == 1
assert len(scheduler_output1.scheduled_new_reqs[0].block_ids[0]) == 5
# Get the output of the first request.
model_runner_output0 = ModelRunnerOutput(
req_ids=[requests[0].request_id],
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output0, model_runner_output0)
# Schedule the first request again. This will cause the preemption
# of the second request because the KV cache is full.
_ = scheduler.schedule()
assert len(scheduler.running) == 1
assert scheduler.running[0] == requests[0]
assert requests[1].status == RequestStatus.PREEMPTED
model_runner_output1 = ModelRunnerOutput(
req_ids=[requests[1].request_id],
sampled_token_ids=[[42]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output1, model_runner_output1)
# The second request (that is preempted) should be updated with the
# sampled token id.
assert len(requests[1].output_token_ids) == 1
assert requests[1].output_token_ids[0] == 42
# Note - these test cases mirror some of those in test_rejection_sampler.py
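# expected = (num_drafts, num_draft_tokens, num_accepted_tokens,
#             num_accepted_tokens_per_pos)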
@pytest.mark.parametrize(
"spec_tokens,output_tokens,expected",
[
([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match
([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch
([[1, 2], [3]], [[1, 2, 5], [3, 4]], (2, 3, 3, [2, 1])), # multiple sequences
([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence
([[]], [[5]], (0, 0, 0, [0])), # empty sequence
(
[[1, 2, 3], [4, 5, 6]],
[[1, 2, 7], [4, 8]],
(2, 6, 3, [2, 1, 0]),
), # multiple mismatches
],
)
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
"""Test scheduling behavior with speculative decoding.
This test verifies that:
1. Speculated tokens get scheduled correctly
2. Spec decoding stats properly count number of draft and accepted tokens
"""
num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
req_ids = []
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
# Schedule a decode, which will also draft speculative tokens
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.total_num_scheduled_tokens == len(requests)
for i in range(len(requests)):
req_id = requests[i].request_id
assert output.num_scheduled_tokens[req_id] == 1
assert req_id not in output.scheduled_spec_decode_tokens
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
sampled_token_ids=[[0] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
engine_core_outputs = scheduler.update_from_output(output, model_runner_output)
draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
scheduler.update_draft_token_ids(draft_token_ids)
for i in range(len(requests)):
running_req = scheduler.running[i]
# The prompt token
assert running_req.num_computed_tokens == 1
# The prompt token and the sampled token
assert running_req.num_tokens == 2
# The prompt token, the sampled token, and the speculated tokens
assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
# No draft or accepted tokens counted yet
assert not engine_core_outputs or (
engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None
)
# Schedule the speculated tokens for validation
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 0
# The sampled token and speculated tokens
assert output.total_num_scheduled_tokens == len(requests) + sum(
len(ids) for ids in spec_tokens
)
for i in range(len(requests)):
req_id = requests[i].request_id
assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
if spec_tokens[i]:
assert len(output.scheduled_spec_decode_tokens[req_id]) == len(
spec_tokens[i]
)
else:
assert req_id not in output.scheduled_spec_decode_tokens
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
sampled_token_ids=output_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
engine_core_outputs = scheduler.update_from_output(output, model_runner_output)
scheduler_stats = (
engine_core_outputs[0].scheduler_stats if engine_core_outputs else None
)
if expected[0] == 0:
assert scheduler_stats is not None
assert scheduler_stats.spec_decoding_stats is None
else:
assert scheduler_stats is not None
assert scheduler_stats.spec_decoding_stats is not None
stats = scheduler_stats.spec_decoding_stats
assert stats.num_drafts == expected[0]
assert stats.num_draft_tokens == expected[1]
assert stats.num_accepted_tokens == expected[2]
assert stats.num_accepted_tokens_per_pos == expected[3]
def _assert_right_scheduler_output(
output: SchedulerOutput,
num_requests: int,
expected_num_scheduled_tokens: int,
):
"""Check if SchedulerOutput is correct after remote KV cache hit."""
# We should inject the kv_connector_metadata.
assert len(output.kv_connector_metadata.requests) == num_requests
# Only num_tokens - matched_num_new_tokens should be scheduled.
for _, num_scheduled_tokens in output.num_scheduled_tokens.items():
assert num_scheduled_tokens == expected_num_scheduled_tokens
def _assert_right_kv_cache_manager(
scheduler: Scheduler,
requests: list[Request],
num_tokens: int,
block_size: int,
num_requests: int,
num_total_blocks: int,
):
"""Check whether KVCacheManager is correct after allocate."""
# Make sure the request stats are right.
EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
for req in requests:
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks[req.request_id]
hashes = req.block_hashes
assert (
scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].num_cached_block[req.request_id]
== EXPECTED_TOTAL_BLOCKS
)
assert len(blocks) == EXPECTED_TOTAL_BLOCKS
assert len(hashes) == EXPECTED_TOTAL_BLOCKS
# Make sure we actually touched all the blocks.
BLOCKS_PER_REQ = num_tokens / block_size
assert (
scheduler.kv_cache_manager.block_pool.get_num_free_blocks()
== num_total_blocks - num_requests * BLOCKS_PER_REQ
)
def _step_until_done(
scheduler: Scheduler,
output: SchedulerOutput,
model_runner_output: ModelRunnerOutput,
):
"""Loop over schedule(), update_from_output() until finished."""
all_finished = False
_ = scheduler.update_from_output(output, model_runner_output)
while not all_finished:
# Schedule + a few iterations until stopping.
output = scheduler.schedule()
assert len(scheduler.running)
for _, num_scheduled_tokens in output.num_scheduled_tokens.items():
# We should be in the decode phase now.
assert num_scheduled_tokens == 1
assert len(output.kv_connector_metadata.requests) == 0
ecos = scheduler.update_from_output(output, model_runner_output)[0]
all_done = True
for eco in ecos.outputs:
if eco.finish_reason is None:
all_done = False
all_finished = all_done
def test_kv_connector_basic():
"""
Test whether Scheduler with KVConnector schedules tokens, allocates
memory, and cleans up requests as expected under normal operation.
"""
# Setup Scheduler.
scheduler = create_scheduler(
enable_prefix_caching=True,
use_kv_connector=True,
)
NUM_TOTAL_BLOCKS = scheduler.kv_cache_manager.block_pool.get_num_free_blocks()
BLOCK_SIZE = scheduler.cache_config.block_size
# Mock External Cache Hit.
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (
NUM_MATCHED_NEW_TOKENS,
False,
)
######################################################
# FIRST SET OF REQUESTS - External Hit Only
NUM_REQUESTS = 2
NUM_TOKENS = NUM_MATCHED_NEW_TOKENS * 2
MAX_TOKENS = 3
requests = create_requests(
num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE,
)
req_ids = []
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
sampled_token_ids=[[1000]] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# Ensure ScheduleOutput is correct.
output = scheduler.schedule()
_assert_right_scheduler_output(
output=output,
num_requests=NUM_REQUESTS,
# Just the incremental tokens should be scheduled.
expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
)
# Ensure KVCacheManager is correct.
_assert_right_kv_cache_manager(
scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS
)
# Continue Generation until done.
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
_ = scheduler.schedule()
# Confirm we clean up the memory properly.
assert (
scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_TOTAL_BLOCKS
)
######################################################
# SECOND SET OF REQUESTS - Local And External Hit
NUM_TOKENS_PREFIX = NUM_TOKENS
# We will get a local prefix cache hit for the first
# NUM_TOKENS_PREFIX tokens since they are used above.
NUM_TOKENS = NUM_TOKENS_PREFIX * 2
requests = create_requests(
num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE,
)
req_ids = []
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
sampled_token_ids=[[1000]] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# We should get a local cache hit of NUM_TOKENS_PREFIX and
# a remote KV cache hit of NUM_MATCHED_NEW_TOKENS.
output = scheduler.schedule()
_assert_right_scheduler_output(
output=output,
num_requests=NUM_REQUESTS,
# Just the incremental tokens after local + remote cache hit.
expected_num_scheduled_tokens=(
NUM_TOKENS - NUM_TOKENS_PREFIX - NUM_MATCHED_NEW_TOKENS
),
)
# Ensure KVCacheManager is correct.
_assert_right_kv_cache_manager(
scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS
)
# Continue Generation until done.
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
_ = scheduler.schedule()
# Confirm we clean up the memory properly.
assert (
scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_TOTAL_BLOCKS
)
def test_kv_connector_unable_to_allocate():
"""
Test that the scheduler with a KVConnector handles allocation failure
(running out of blocks in allocate_slots()).
"""
# Setup Scheduler With Mock External Cache Hit.
BLOCK_SIZE = 4
NUM_BLOCKS = 10
scheduler = create_scheduler(
enable_prefix_caching=True,
use_kv_connector=True,
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
)
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (
NUM_MATCHED_NEW_TOKENS,
False,
)
# Create two requests. The second request will not be able to
# allocate slots because it will not have enough blocks.
NUM_REQUESTS = 2
NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE
MAX_TOKENS = 2
requests = create_requests(
num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE,
)
req_ids = []
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
sampled_token_ids=[[1000]] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# Just one request should be running.
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
num_requests=1,
expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
# All memory should be freed, with one request waiting.
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
# Just one request should be running.
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
num_requests=1,
expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
# All memory should be freed, with no requests waiting / running.
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 0
def test_kv_connector_handles_preemption():
"""
Test that the scheduler with a KVConnector correctly handles preemption
and resumption when KV cache blocks run out.
"""
# Setup Scheduler With Mock External Cache Hit.
BLOCK_SIZE = 2
# NOTE: there is 1 null block, so this is 6 blocks.
NUM_BLOCKS = 7
scheduler = create_scheduler(
enable_prefix_caching=True,
use_kv_connector=True,
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
)
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (
NUM_MATCHED_NEW_TOKENS,
False,
)
# Create two requests.
# Both can be scheduled at first, but the second request
# will be preempted and re-scheduled.
NUM_REQUESTS = 2
NUM_TOKENS = BLOCK_SIZE * 2 + 1
MAX_TOKENS = BLOCK_SIZE * 2
requests = create_requests(
num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE,
)
req_ids = []
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
sampled_token_ids=[[1000]] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# All can be scheduled - 1st token.
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# 2 remote kv cache hits.
num_requests=2,
expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS,
)
assert len(scheduler.running) == 2
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
# All can be scheduled - 2nd token.
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# no connector_metadata
num_requests=0,
expected_num_scheduled_tokens=1,
)
assert len(scheduler.running) == 2
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
# This will generate a new block and cause a preemption - 3rd token.
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# no connector_metadata
num_requests=0,
expected_num_scheduled_tokens=1,
)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
# Only 1 can be scheduled - 4th (and last token).
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# no connector_metadata
num_requests=0,
expected_num_scheduled_tokens=1,
)
assert len(scheduler.waiting) == 1
assert len(scheduler.running) == 1
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 0
# All memory should be freed since nothing is running.
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
# Restarts the preempted request - generate 3rd token.
# This will have a local and remote cache hit.
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# 1 remote kv_cache hit!
num_requests=1,
# Only 1 block was preempted and there is a single remote hit,
# so only a single new token is scheduled.
expected_num_scheduled_tokens=1,
)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
# Only 1 can be scheduled - 4th (and last token).
output = scheduler.schedule()
_assert_right_scheduler_output(
output,
# no connector_metadata
num_requests=0,
expected_num_scheduled_tokens=1,
)
assert len(scheduler.running) == 1
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 0
# All memory should be freed since nothing is running.
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1
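# Helper: build a ModelRunnerOutput with one sampled token per running request.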
def make_output(scheduler: Scheduler):
return ModelRunnerOutput(
req_ids=[req.request_id for req in scheduler.running],
sampled_token_ids=[[1000]] * len(scheduler.running),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
def assert_scheduler_empty(scheduler: Scheduler):
"""Confirm the scheduler is "empty" - i.e. no leaks."""
# Scheduler Metadata.
assert len(scheduler.requests) == 0
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0
# EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
assert (
len(
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks
)
== 0
)
assert (
len(
scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].num_cached_block
)
== 0
)
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
)
assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
# NOTE(rob): only the ref count on the blocks will be 0. The hash
# value, etc. will remain since we lazily evict for the prefix cache.
for block in scheduler.kv_cache_manager.block_pool.blocks:
assert block.ref_cnt == 0
# assert block._block_hash is None
# assert (
# len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block
# ) == 0)
def test_memory_leak():
"""Test that we do not have a memory leak."""
scheduler = create_scheduler(enable_prefix_caching=True)
NUM_REQUESTS = 5
NUM_TOKENS = 10
MAX_TOKENS = 10
requests = create_requests(
num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, max_tokens=MAX_TOKENS
)
# Add each request.
for request in requests:
scheduler.add_request(request)
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Iterate until done.
while True:
scheduler_output = scheduler.schedule()
if len(scheduler.running) == 0:
break
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm no memory leak.
assert_scheduler_empty(scheduler)
def create_scheduler_with_priority(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
enable_prefix_caching: Optional[bool] = None,
long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False,
use_kv_connector: bool = False,
num_blocks: int = 10000,
block_size: int = 16,
max_model_len: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
) -> Scheduler:
"""Create scheduler with priority policy enabled.
Args:
model: model under test
max_num_seqs: max sequences to schedule
max_num_batched_tokens: max number of tokens to batch
enable_prefix_caching: optionally force APC config
(True/False) or use default
(None)
Returns:
{class}`Scheduler` instance with priority scheduling
"""
if max_model_len is None:
max_model_len = max_num_batched_tokens
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_model_len,
long_prefill_token_threshold=long_prefill_token_threshold,
disable_chunked_mm_input=disable_chunked_mm_input,
enable_chunked_prefill=True,
policy="priority", # Enable priority scheduling
)
model_config = ModelConfig(
model=model,
trust_remote_code=True,
dtype="float16",
seed=42,
)
# Cache config, optionally force APC
kwargs_cache = (
{}
if enable_prefix_caching is None
else {"enable_prefix_caching": enable_prefix_caching}
)
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
**kwargs_cache,
)
kv_transfer_config = (
KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
)
if use_kv_connector
else None
)
speculative_config: Optional[SpeculativeConfig] = None
if num_speculative_tokens is not None:
speculative_config = SpeculativeConfig(
model="ngram", num_speculative_tokens=num_speculative_tokens
)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
)
],
)
cache_config.num_gpu_blocks = num_blocks
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
block_size=block_size,
)
def create_requests_with_priority(
num_requests: int,
priorities: list[int],
arrival_times: Optional[list[float]] = None,
num_tokens: int = 10,
mm_positions: Optional[list[list[PlaceholderRange]]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None,
starting_idx: int = 0,
):
"""Create requests with specified priorities and arrival times."""
assert len(priorities) == num_requests
if arrival_times is not None:
assert len(arrival_times) == num_requests
else:
arrival_times = [float(i) for i in range(num_requests)]
sampling_params = SamplingParams(
ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs,
)
requests = []
for i in range(num_requests):
mm_features = []
if mm_positions is not None:
mm_position = mm_positions[i]
for j, position in enumerate(mm_position):
identifier = f"hash{i}_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image",
)
mm_features.append(mm_feature)
request = Request(
request_id=f"{i + starting_idx}",
prompt_token_ids=[i + starting_idx] * num_tokens,
sampling_params=sampling_params,
pooling_params=None,
mm_features=mm_features if mm_features else None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=arrival_times[i],
priority=priorities[i],
)
requests.append(request)
return requests
def test_priority_scheduling_basic_ordering():
"""Test that requests are scheduled in priority order
(lower value = higher priority)."""
scheduler = create_scheduler_with_priority()
# Create requests with different priorities
# Priority 0 (highest), 1, 2 (lowest)
priorities = [2, 0, 1] # Add in non-priority order
arrival_times = [1.0, 2.0, 3.0] # All different arrival times
requests = create_requests_with_priority(
num_requests=3, priorities=priorities, arrival_times=arrival_times
)
# Add requests in non-priority order
for request in requests:
scheduler.add_request(request)
# Schedule and verify priority order
output = scheduler.schedule()
# Should schedule all requests since they fit in budget
assert len(output.scheduled_new_reqs) == 3
# Verify they are scheduled in priority order:
# req_1 (priority 0), req_2 (priority 1), req_0 (priority 2)
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
assert scheduled_req_ids == ["1", "2", "0"]
def test_priority_scheduling_arrival_time_tiebreaker():
"""Test that arrival time is used
as tiebreaker when priorities are equal."""
scheduler = create_scheduler_with_priority()
# Create requests with same priority but different arrival times
priorities = [1, 1, 1] # All same priority
arrival_times = [3.0, 1.0, 2.0] # Different arrival times
requests = create_requests_with_priority(
num_requests=3, priorities=priorities, arrival_times=arrival_times
)
# Add requests in non-arrival order
for request in requests:
scheduler.add_request(request)
# Schedule and verify arrival time order
output = scheduler.schedule()
# Should schedule all requests since they fit in budget
assert len(output.scheduled_new_reqs) == 3
# Verify they are scheduled in arrival time order:
# req_1 (1.0), req_2 (2.0), req_0 (3.0)
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
assert scheduled_req_ids == ["1", "2", "0"]
def test_priority_scheduling_mixed_priority_and_arrival():
"""Test priority scheduling with mixed priorities and arrival times."""
scheduler = create_scheduler_with_priority()
# Create requests with mixed priorities and arrival times
priorities = [2, 1, 1, 0] # Mixed priorities
arrival_times = [1.0, 3.0, 2.0, 4.0] # Mixed arrival times
requests = create_requests_with_priority(
num_requests=4, priorities=priorities, arrival_times=arrival_times
)
# Add requests
for request in requests:
scheduler.add_request(request)
# Schedule and verify order
output = scheduler.schedule()
# Should schedule all requests since they fit in budget
assert len(output.scheduled_new_reqs) == 4
# Expected order:
# 1. req_3 (priority 0, arrival 4.0)
# 2. req_2 (priority 1, arrival 2.0) - earlier arrival than req_1
# 3. req_1 (priority 1, arrival 3.0)
# 4. req_0 (priority 2, arrival 1.0)
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
assert scheduled_req_ids == ["3", "2", "1", "0"]
def test_priority_scheduling_preemption():
"""Test that priority scheduling preempts
lower priority requests when memory is constrained."""
# Create scheduler with very limited memory to force preemption
scheduler = create_scheduler_with_priority(
max_num_seqs=3, # Allow multiple requests
max_num_batched_tokens=200,
num_blocks=6, # Very limited blocks to force memory pressure
block_size=16, # Standard block size
)
# Create initial low-priority requests that will consume most memory
low_priority_requests = create_requests_with_priority(
num_requests=2,
priorities=[5, 5], # Low priority
arrival_times=[1.0, 2.0],
num_tokens=30, # Large enough to consume significant memory
)
# Add and schedule low priority requests
for request in low_priority_requests:
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 2
# Simulate model execution to move requests to running state
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in low_priority_requests],
sampled_token_ids=[[100] for _ in low_priority_requests],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# Verify both requests are running
assert len(scheduler.running) == 2
# Now add a high-priority request that requires memory allocation
# This should trigger preemption due to memory constraints
high_priority_request = create_requests_with_priority(
num_requests=1,
priorities=[0], # High priority
arrival_times=[3.0],
num_tokens=30, # Large enough to require significant memory
)[0]
scheduler.add_request(high_priority_request)
# Schedule again - this should trigger
# preemption when trying to allocate memory
output = scheduler.schedule()
# Due to the scheduler's design, if preemption happens while scheduling
# running requests, waiting requests won't be scheduled in the same step.
# Check whether preemption occurred by inspecting the waiting queue.
if len(scheduler.waiting) > 1: # high priority + preempted request
# Preemption occurred - verify the high priority request
# gets scheduled next
output2 = scheduler.schedule()
assert len(output2.scheduled_new_reqs) == 1
# High priority request
assert output2.scheduled_new_reqs[0].req_id == "0"
else:
# No preemption needed - all requests fit
# This is also valid behavior if memory allows
assert len(output.scheduled_new_reqs) == 1
# High priority request
assert output.scheduled_new_reqs[0].req_id == "0"
def test_priority_scheduling_no_preemption_when_space_available():
"""Test that preemption doesn't happen
when there's space for new requests."""
scheduler = create_scheduler_with_priority(
max_num_seqs=3, # Allow 3 concurrent requests
max_num_batched_tokens=200, # Sufficient token budget
)
# Add two low-priority running requests
low_priority_requests = create_requests_with_priority(
num_requests=2, priorities=[5, 5], arrival_times=[1.0, 2.0], num_tokens=30
)
for request in low_priority_requests:
scheduler.add_request(request)
output = scheduler.schedule()
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in low_priority_requests],
sampled_token_ids=[[100] for _ in low_priority_requests],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# Add high-priority request
high_priority_request = create_requests_with_priority(
num_requests=1, priorities=[0], arrival_times=[3.0], num_tokens=30
)[0]
scheduler.add_request(high_priority_request)
# Schedule - should not preempt since there's space
output = scheduler.schedule()
# Should schedule the new request without preemption
assert len(output.scheduled_new_reqs) == 1
assert len(scheduler.running) == 3 # All three requests running
assert len(scheduler.waiting) == 0 # No requests waiting
def test_priority_scheduling_preemption_victim_selection():
"""Test that the correct victim is selected for
preemption based on priority and arrival time."""
# This test verifies the priority-based victim selection logic
# by checking the waiting queue order after adding requests with different
# priorities
scheduler = create_scheduler_with_priority(
max_num_seqs=1, # Force sequential processing to test priority order
)
# Create requests with different priorities
requests = create_requests_with_priority(
num_requests=3,
priorities=[3, 2, 0], # Different priorities: low, medium, high
arrival_times=[1.0, 2.0, 3.0],
num_tokens=10,
)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule - should only schedule the highest priority request
# (req_2, priority 0)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_new_reqs[0].req_id == "2" # Highest priority
# Verify the waiting queue has the remaining requests in priority order
assert len(scheduler.waiting) == 2
# Extract waiting requests and verify priority order
waiting_requests = list(scheduler.waiting)
waiting_priorities = [req.priority for req in waiting_requests]
waiting_req_ids = [req.request_id for req in waiting_requests]
# Should be req_1 (priority 2) then req_0 (priority 3)
assert waiting_priorities == [2, 3]
assert waiting_req_ids == ["1", "0"]
def test_priority_scheduling_equal_priority_preemption():
"""Test arrival time tiebreaker when requests have equal priority."""
# This test verifies that arrival time is used as a tiebreaker for equal
# priorities
scheduler = create_scheduler_with_priority(
max_num_seqs=1, # Force sequential processing
)
# Create requests with same priority but different arrival times
requests = create_requests_with_priority(
num_requests=3,
priorities=[2, 2, 2], # Same priority
arrival_times=[3.0, 1.0, 2.0], # Different arrival times
num_tokens=10,
)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule - should schedule the request with earliest arrival time
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_new_reqs[0].req_id == "1" # Earliest arrival (1.0)
# Verify the waiting queue has remaining requests in arrival time order
assert len(scheduler.waiting) == 2
# Extract waiting requests and verify arrival time order
waiting_requests = list(scheduler.waiting)
waiting_arrival_times = [req.arrival_time for req in waiting_requests]
waiting_req_ids = [req.request_id for req in waiting_requests]
# Should be req_2 (arrival 2.0) then req_0 (arrival 3.0)
assert waiting_arrival_times == [2.0, 3.0]
assert waiting_req_ids == ["2", "0"]
def test_priority_scheduling_waiting_queue_order():
"""Test that the waiting queue maintains priority order."""
scheduler = create_scheduler_with_priority(
max_num_seqs=1, # Only one request can run at a time
)
# Create multiple requests with different priorities
requests = create_requests_with_priority(
num_requests=4,
priorities=[3, 1, 2, 0], # Mixed priorities
arrival_times=[1.0, 2.0, 3.0, 4.0],
num_tokens=10,
)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule - should only schedule the highest priority request
# (req_3, priority 0)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_new_reqs[0].req_id == "3"
# Verify waiting queue has remaining requests in priority order
assert len(scheduler.waiting) == 3
# Extract requests from waiting queue
# (it's a heap, so we need to pop to see order)
waiting_requests = list(scheduler.waiting)
waiting_priorities = [req.priority for req in waiting_requests]
waiting_req_ids = [req.request_id for req in waiting_requests]
# Should be ordered by priority: req_1 (1), req_2 (2), req_0 (3)
assert waiting_req_ids == ["1", "2", "0"]
assert waiting_priorities == [1, 2, 3]
def test_priority_scheduling_fcfs_fallback():
"""Test that FCFS behavior is maintained when all
requests have same priority."""
scheduler = create_scheduler_with_priority()
# Create requests with same priority but different arrival times
priorities = [1, 1, 1, 1] # All same priority
arrival_times = [4.0, 1.0, 3.0, 2.0] # Different arrival times
requests = create_requests_with_priority(
num_requests=4, priorities=priorities, arrival_times=arrival_times
)
# Add requests
for request in requests:
scheduler.add_request(request)
# Schedule
output = scheduler.schedule()
# Should schedule all requests in arrival time order
assert len(output.scheduled_new_reqs) == 4
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
# Expected order by arrival time:
# req_1 (1.0), req_3 (2.0), req_2 (3.0), req_0 (4.0)
assert scheduled_req_ids == ["1", "3", "2", "0"]
def test_priority_scheduling_with_limited_slots():
"""Test priority scheduling when max_num_seqs limits concurrent requests."""
scheduler = create_scheduler_with_priority(
max_num_seqs=2, # Only allow 2 concurrent requests
max_num_batched_tokens=1000, # Plenty of token budget
)
# Create requests with different priorities
requests = create_requests_with_priority(
num_requests=4,
priorities=[3, 1, 2, 0], # Mixed priorities
arrival_times=[1.0, 2.0, 3.0, 4.0],
num_tokens=10,
)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule - should only schedule the 2 highest priority requests
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 2
# Should schedule req_3 (priority 0) and req_1 (priority 1)
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
assert "3" in scheduled_req_ids # Priority 0
assert "1" in scheduled_req_ids # Priority 1
# Remaining requests should be in waiting queue in priority order
assert len(scheduler.waiting) == 2
# Extract waiting requests and verify order
waiting_requests = list(scheduler.waiting)
waiting_priorities = [req.priority for req in waiting_requests]
waiting_req_ids = [req.request_id for req in waiting_requests]
# Should be req_2 (priority 2) then req_0 (priority 3)
assert waiting_priorities == [2, 3]
assert waiting_req_ids == ["2", "0"]
def test_priority_scheduling_heap_property():
"""Test that the waiting queue maintains heap
property for priority scheduling."""
scheduler = create_scheduler_with_priority(
max_num_seqs=1, # Only one request can run at a time
)
# Add requests in random priority order
priorities = [5, 1, 8, 3, 2, 7, 4, 6]
arrival_times = [float(i) for i in range(len(priorities))]
requests = create_requests_with_priority(
num_requests=len(priorities),
priorities=priorities,
arrival_times=arrival_times,
num_tokens=10,
)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule one request at a time and verify priority order
scheduled_priorities = []
while scheduler.waiting:
output = scheduler.schedule()
if output.scheduled_new_reqs:
req = output.scheduled_new_reqs[0]
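            # req_id is the stringified index into `requests`, so it maps
            # back to the original priority of the scheduled request.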
scheduled_priorities.append(requests[int(req.req_id)].priority)
# Simulate completion to make room for next request
model_output = ModelRunnerOutput(
req_ids=[req.req_id],
sampled_token_ids=[[100]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# Finish the request to make room for the next one
scheduler.finish_requests(req.req_id, RequestStatus.FINISHED_STOPPED)
# Verify requests were scheduled in priority order (lowest value first)
expected_priorities = sorted(priorities)
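    # e.g. sorted([5, 1, 8, 3, 2, 7, 4, 6]) == [1, 2, 3, 4, 5, 6, 7, 8]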
    assert scheduled_priorities == expected_priorities


def test_schedule_skip_tokenizer_init():
scheduler = create_scheduler(skip_tokenizer_init=True)
requests = create_requests(num_requests=5)
for request in requests:
scheduler.add_request(request)
output = scheduler.schedule()
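    # With tokenizer init skipped and no structured outputs involved, requests
    # are scheduled normally and no grammar bitmask is produced.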
assert len(output.scheduled_new_reqs) == len(requests)
    assert output.grammar_bitmask is None


def test_schedule_skip_tokenizer_init_structured_output_request():
scheduler = create_scheduler(skip_tokenizer_init=True)
structured_outputs_params = StructuredOutputsParams(regex="[0-9]+")
sampling_params = SamplingParams(
ignore_eos=False,
max_tokens=16,
structured_outputs=structured_outputs_params,
)
request = Request(
request_id="0",
prompt_token_ids=[0, 1],
mm_features=None,
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
structured_output_request=StructuredOutputRequest(sampling_params),
)
scheduler.add_request(request)
output = scheduler.schedule()
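    # The structured output request is never scheduled: with tokenizer init
    # skipped, its grammar cannot become ready, so it stays in the waiting queue.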
assert len(output.scheduled_new_reqs) == 0
assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1


def test_priority_scheduling_preemption_and_resumption_when_out_of_kv():
"""Test that priority scheduling preempts lower priority requests
when out of KV cache space."""
# Create scheduler with very limited memory to force preemption
scheduler = create_scheduler_with_priority(
max_num_seqs=2, # Allow multiple requests
max_num_batched_tokens=200,
num_blocks=5, # Can hold 64 tokens (first block is null)
block_size=16, # Standard block size
use_kv_connector=True,
)
    # Create a low-priority request and schedule it
request_low = create_requests_with_priority(
num_requests=1,
priorities=[1],
arrival_times=[0.0],
num_tokens=30,
starting_idx=0,
)[0]
scheduler.add_request(request_low)
# 1st schedule
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 1
# Simulate model execution - 1st decode
model_output = ModelRunnerOutput(
req_ids=[request_low.request_id],
sampled_token_ids=[[100]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
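    # req_low now holds 31 tokens (30 prompt + 1 decoded), i.e. 2 blocks.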
    # Create a high-priority request and schedule it
request_high = create_requests_with_priority(
num_requests=1,
priorities=[0],
arrival_times=[1.0],
num_tokens=32,
starting_idx=1,
)[0]
scheduler.add_request(request_high)
# 2nd schedule
output = scheduler.schedule()
# KV cache should be full at this point
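    # (req_low: 31 tokens -> 2 blocks, req_high: 32 prompt tokens -> 2 blocks,
    # which uses up all 4 usable, non-null blocks)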
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_cached_reqs.num_reqs == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 2
# Simulate model execution - 2nd decode
requests = [request_low, request_high]
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
sampled_token_ids=[[100] for _ in requests],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
    # 3rd schedule - this should trigger preemption:
    # req_low needs 32 tokens = 2 blocks,
    # req_high needs 33 tokens = 3 blocks,
    # so together they need 5 blocks, which doesn't fit in the 4 available blocks.
output = scheduler.schedule()
# Should have preempted req_low
assert len(output.scheduled_new_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 1
assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED
assert len(scheduler.waiting) == 1
assert len(scheduler.running) == 1
# Simulate model execution - 3rd decode
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
sampled_token_ids=[[], [100]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
    # Finish the high-priority request to make room for the
    # preempted request to resume
scheduler.update_from_output(output, model_output)
scheduler.finish_requests(request_high.request_id, RequestStatus.FINISHED_STOPPED)
    # 4th schedule - this should trigger the resumption
output = scheduler.schedule()
scheduled_cached_reqs = output.scheduled_cached_reqs
resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption
assert len(output.scheduled_new_reqs) == 0
assert scheduled_cached_reqs.num_reqs == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 1
# Preempted request resumed in scheduled_cached_reqs
assert len(resumed_from_preemption) == 1
assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1
assert resumed_from_preemption[0]
assert scheduled_cached_reqs.req_ids[0] == request_low.request_id
assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None
# Resumed tokens include 30 prompt tokens and 2 decoded tokens
assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 32
    assert scheduled_cached_reqs.resumed_req_token_ids[0][31] == 100


@pytest.mark.parametrize(
("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"),
[
(True, False, True),
(False, False, False),
# Encoder-decoder models should always have it disabled
(False, True, False),
(True, True, False),
],
)
def test_chunked_prefill_disabled_for_encoder_decoder(
enable_chunked_prefill: bool, is_encoder_decoder: bool, expect_enabled: bool
) -> None:
"""Validate that chunked prefill is appropriately disabled for
encoder-decoder models."""
scheduler_config = SchedulerConfig(
enable_chunked_prefill=enable_chunked_prefill,
is_encoder_decoder=is_encoder_decoder,
)
# `is_encoder_decoder` should only be used during construction
# of the config, and otherwise stored in the model config.
assert "is_encoder_decoder" not in vars(scheduler_config)
assert "is_encoder_decoder" not in [
f.name for f in dataclasses.fields(scheduler_config)
]
_validate_chunked_prefill_settings_for_encoder_decoder(
scheduler_config, is_encoder_decoder, expect_enabled
)
# Ensure it is retained in VllmConfig, even after its post-init.
vllm_config = VllmConfig(scheduler_config=scheduler_config)
_validate_chunked_prefill_settings_for_encoder_decoder(
vllm_config.scheduler_config, is_encoder_decoder, expect_enabled
    )


def _validate_chunked_prefill_settings_for_encoder_decoder(
scheduler_config: SchedulerConfig, is_encoder_decoder: bool, expect_enabled: bool
) -> None:
"""Validate chunked prefill settings in the scheduler config for
encoder-decoder models."""
assert scheduler_config.chunked_prefill_enabled is expect_enabled
assert scheduler_config.enable_chunked_prefill is expect_enabled
if is_encoder_decoder:
# Encoder-decoder models should automatically disable chunked multimodal
# inputs as well
assert scheduler_config.disable_chunked_mm_input is not expect_enabled
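    # With chunked prefill disabled for encoder-decoder models, the long
    # prefill token threshold should be reset to 0 as well.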
if is_encoder_decoder and not expect_enabled:
assert scheduler_config.long_prefill_token_threshold == 0