Compare commits


1 commit

b73fdb927a  draft
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
2025-05-03 10:50:34 -07:00

3 changed files with 74 additions and 0 deletions

View File

@@ -24,6 +24,7 @@ def create_scheduler(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
max_num_spec_tokens: Optional[int] = None,
enable_prefix_caching: Optional[bool] = None,
long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False,
@@ -51,6 +52,7 @@ def create_scheduler(
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_num_spec_tokens=max_num_spec_tokens,
max_model_len=max_model_len,
long_prefill_token_threshold=long_prefill_token_threshold,
disable_chunked_mm_input=disable_chunked_mm_input,
@@ -684,6 +686,59 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
scheduler.update_from_output(scheduler_output1, model_runner_output)

def test_spec_token_budget():
    """Test scheduling behavior when the spec token budget limits the total
    number of scheduled tokens."""
    # Create a scheduler whose batch-wide spec token budget is 14.
    scheduler = create_scheduler(
        max_num_batched_tokens=100,
        max_num_spec_tokens=14,  # Total spec budget for this test
    )
    requests = create_requests(
        num_requests=2,
        num_tokens=10,
    )
    # request1 proposes 10 draft tokens, request2 proposes 5.
    spec_tokens = [list(range(10)), list(range(5))]
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    # First pass: prefill both requests, then attach the draft tokens.
    output = scheduler.schedule()
    model_runner_output = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[0] for _ in range(len(requests))],
        spec_token_ids=spec_tokens,
        logprobs=None,
        prompt_logprobs_dict={},
    )
    scheduler.update_from_output(output, model_runner_output)

    # Second pass: the spec token budget now applies.
    output = scheduler.schedule()
    request1, request2 = requests

    # --- Verify request1 ---
    # num_new_tokens = 11 (1 sampled token + 10 spec tokens)
    # num_scheduled_spec_tokens = 11 + 10 - 11 = 10
    # Budget starts at 14: 10 <= 14 → no truncation
    assert len(request1.spec_token_ids) == 10  # Not truncated
    assert output.num_scheduled_tokens[request1.request_id] == 11
    assert len(output.scheduled_spec_decode_tokens[request1.request_id]) == 10

    # --- Verify request2 ---
    # Remaining budget after request1: 14 - 10 = 4
    # num_new_tokens = 6 (1 sampled token + 5 spec tokens)
    # num_scheduled_spec_tokens = 6 + 10 - 11 = 5 > 4 → truncate to 4
    # num_new_tokens = min(6, 4 + 1) = 5
    assert len(request2.spec_token_ids) == 4  # Truncated from 5
    assert output.num_scheduled_tokens[request2.request_id] == 5
    assert len(output.scheduled_spec_decode_tokens[request2.request_id]) == 4

# Note - these test cases mirror some of those in test_rejection_sampler.py
@pytest.mark.parametrize(
"spec_tokens,output_tokens,expected",

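For reference, the budget arithmetic these assertions encode can be reproduced outside the scheduler. A standalone sketch in plain Python (`spec_len` and the printed tuples are illustrative names, not scheduler attributes):

# Walk through the spec-token budget math asserted above. After prefill,
# each request has a 10-token prompt plus 1 sampled token (num_tokens = 11)
# and num_computed_tokens = 10.
spec_token_budget = 14
for name, spec_len in [("request1", 10), ("request2", 5)]:
    num_tokens, num_computed_tokens = 11, 10
    # Tokens to schedule this step: the last sampled token plus all drafts.
    num_new_tokens = num_tokens + spec_len - num_computed_tokens
    num_scheduled_spec_tokens = (num_new_tokens + num_computed_tokens -
                                 num_tokens)
    if num_scheduled_spec_tokens > spec_token_budget:
        num_scheduled_spec_tokens = spec_token_budget
        # +1 keeps the last generated (non-speculative) token.
        num_new_tokens = min(num_new_tokens, num_scheduled_spec_tokens + 1)
    spec_token_budget -= num_scheduled_spec_tokens
    print(name, num_new_tokens, num_scheduled_spec_tokens, spec_token_budget)
# Prints: request1 11 10 4, then request2 5 4 0 — matching the asserts.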
View File

@@ -1841,6 +1841,9 @@ class SchedulerConfig:
is primarily set in `ModelConfig` and that value should be manually
duplicated here."""
max_num_spec_tokens: int = None # type: ignore
"""Maximum number of speculative tokens for all requests in the batch."""
max_num_partial_prefills: int = 1
"""For chunked prefill, the maximum number of sequences that can be
partially prefilled concurrently."""

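Because the new field defaults to None, existing configurations are unaffected and callers opt in explicitly. A minimal usage sketch, assuming this diff is applied; it constructs SchedulerConfig directly, the way the test above does, with the other fields left at their defaults:

from vllm.config import SchedulerConfig

# Cap the whole batch at 14 speculative (draft) tokens per scheduling
# step. Leaving max_num_spec_tokens unset keeps the default None,
# which disables the cap.
config = SchedulerConfig(
    max_num_seqs=16,
    max_num_batched_tokens=8192,
    max_num_spec_tokens=14,
)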
View File

@@ -62,6 +62,7 @@ class Scheduler(SchedulerInterface):
self.max_num_scheduled_tokens = \
self.scheduler_config.max_num_batched_tokens
self.max_model_len = self.scheduler_config.max_model_len
self.max_num_spec_tokens = self.scheduler_config.max_num_spec_tokens
# Create KVConnector for the Scheduler. Note that each Worker
# will have a corresponding KVConnector with Role=WORKER.
@@ -162,6 +163,8 @@ class Scheduler(SchedulerInterface):
req_to_new_block_ids: dict[str, list[int]] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
spec_token_budget = self.max_num_spec_tokens
# Encoder-related.
scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_budget = self.max_num_encoder_input_tokens
@@ -184,6 +187,19 @@
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
num_scheduled_spec_tokens = (num_new_tokens +
                             request.num_computed_tokens -
                             request.num_tokens)
# None means no batch-wide budget on speculative tokens.
if spec_token_budget is not None:
    if num_scheduled_spec_tokens > spec_token_budget:
        # We don't truncate the spec_token_ids list here because
        # it will be trimmed at the end of the while loop.
        num_scheduled_spec_tokens = spec_token_budget
        # +1 here to include the last generated token.
        num_new_tokens = min(num_new_tokens,
                             num_scheduled_spec_tokens + 1)
    # Clamp at zero so a chunked prefill (where this value is
    # negative) cannot grow the remaining budget.
    spec_token_budget -= max(num_scheduled_spec_tokens, 0)
# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens = min(
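The trim that the comment above defers to happens further down the same while-loop iteration. A sketch of that step, based on the pre-existing trimming logic the comment points at (simplified; not part of this diff):

if num_scheduled_spec_tokens > 0:
    # Keep only the draft tokens that fit the budget; e.g. request2's
    # 5 proposed tokens become 4 in the test above.
    del request.spec_token_ids[num_scheduled_spec_tokens:]
    scheduled_spec_decode_tokens[request.request_id] = (
        request.spec_token_ids)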