mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Compare commits
1 Commits
v0.11.0rc4
...
disable-sd
Author | SHA1 | Date | |
---|---|---|---|
b73fdb927a |
@ -24,6 +24,7 @@ def create_scheduler(
|
||||
model: str = "facebook/opt-125m",
|
||||
max_num_seqs: int = 16,
|
||||
max_num_batched_tokens: int = 8192,
|
||||
max_num_spec_tokens: Optional[int] = None,
|
||||
enable_prefix_caching: Optional[bool] = None,
|
||||
long_prefill_token_threshold: int = 0,
|
||||
disable_chunked_mm_input: bool = False,
|
||||
@ -51,6 +52,7 @@ def create_scheduler(
|
||||
scheduler_config = SchedulerConfig(
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
max_num_spec_tokens=max_num_spec_tokens,
|
||||
max_model_len=max_model_len,
|
||||
long_prefill_token_threshold=long_prefill_token_threshold,
|
||||
disable_chunked_mm_input=disable_chunked_mm_input,
|
||||
@ -684,6 +686,59 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
|
||||
scheduler.update_from_output(scheduler_output1, model_runner_output)
|
||||
|
||||
|
||||
def test_spec_token_budget():
|
||||
"""Test scheduling behavior when spec token buget limits the total
|
||||
number of scheduled tokens."""
|
||||
# Create scheduler with spec_token_budget=5
|
||||
scheduler = create_scheduler(
|
||||
max_num_batched_tokens=100,
|
||||
max_num_spec_tokens=14, # Total spec budget for this test
|
||||
)
|
||||
|
||||
requests = create_requests(
|
||||
num_requests=2,
|
||||
num_tokens=10,
|
||||
)
|
||||
|
||||
spec_tokens = [list(range(10)), list(range(5))]
|
||||
req_ids = []
|
||||
req_to_index = {}
|
||||
for i, request in enumerate(requests):
|
||||
scheduler.add_request(request)
|
||||
req_ids.append(request.request_id)
|
||||
req_to_index[request.request_id] = i
|
||||
output = scheduler.schedule()
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[[0] for _ in range(len(requests))],
|
||||
spec_token_ids=spec_tokens,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
)
|
||||
scheduler.update_from_output(output, model_runner_output)
|
||||
|
||||
output = scheduler.schedule()
|
||||
request1, request2 = requests
|
||||
# --- Verify request1 ---
|
||||
# num_new_tokens = 11
|
||||
# num_scheduled_spec_tokens = 10
|
||||
# Budget starts at 14: 10 <= 14 → no truncation
|
||||
# num_new_tokens = min(10, 14) → 10
|
||||
assert len(request1.spec_token_ids) == 10 # Not truncated
|
||||
assert output.num_scheduled_tokens[request1.request_id] == 11
|
||||
assert len(output.scheduled_spec_decode_tokens[request1.request_id]) == 10
|
||||
|
||||
# --- Verify request2 ---
|
||||
# Remaining budget after request1: 14 - 10 = 4
|
||||
# num_new_tokens = 6
|
||||
# num_scheduled_spec_tokens = 6-1 = 5 > 4 → truncate to 4
|
||||
# num_new_tokens = min(5, 4) → 4
|
||||
assert len(request2.spec_token_ids) == 4 # Truncated from 5
|
||||
assert output.num_scheduled_tokens[request2.request_id] == 5
|
||||
assert len(output.scheduled_spec_decode_tokens[request2.request_id]) == 4
|
||||
|
||||
|
||||
# Note - these test cases mirror some of those in test_rejection_sampler.py
|
||||
@pytest.mark.parametrize(
|
||||
"spec_tokens,output_tokens,expected",
|
||||
|
@ -1841,6 +1841,9 @@ class SchedulerConfig:
|
||||
is primarily set in `ModelConfig` and that value should be manually
|
||||
duplicated here."""
|
||||
|
||||
max_num_spec_tokens: int = None # type: ignore
|
||||
"""Maximum number of speculative tokens for all requests in the batch."""
|
||||
|
||||
max_num_partial_prefills: int = 1
|
||||
"""For chunked prefill, the maximum number of sequences that can be
|
||||
partially prefilled concurrently."""
|
||||
|
@ -62,6 +62,7 @@ class Scheduler(SchedulerInterface):
|
||||
self.max_num_scheduled_tokens = \
|
||||
self.scheduler_config.max_num_batched_tokens
|
||||
self.max_model_len = self.scheduler_config.max_model_len
|
||||
self.max_num_spec_tokens = self.scheduler_config.max_num_spec_tokens
|
||||
|
||||
# Create KVConnector for the Scheduler. Note that each Worker
|
||||
# will have a corresponding KVConnector with Role=WORKER.
|
||||
@ -162,6 +163,8 @@ class Scheduler(SchedulerInterface):
|
||||
req_to_new_block_ids: dict[str, list[int]] = {}
|
||||
num_scheduled_tokens: dict[str, int] = {}
|
||||
token_budget = self.max_num_scheduled_tokens
|
||||
spec_token_budget = self.max_num_spec_tokens
|
||||
|
||||
# Encoder-related.
|
||||
scheduled_encoder_inputs: dict[str, list[int]] = {}
|
||||
encoder_budget = self.max_num_encoder_input_tokens
|
||||
@ -184,6 +187,19 @@ class Scheduler(SchedulerInterface):
|
||||
self.scheduler_config.long_prefill_token_threshold)
|
||||
num_new_tokens = min(num_new_tokens, token_budget)
|
||||
|
||||
num_scheduled_spec_tokens = (num_new_tokens +
|
||||
request.num_computed_tokens -
|
||||
request.num_tokens)
|
||||
if spec_token_budget:
|
||||
if num_scheduled_spec_tokens > spec_token_budget:
|
||||
# We don't truncate the spec_token_ids list here because
|
||||
# it will be trimmed in the end of the while loop.
|
||||
num_scheduled_spec_tokens = spec_token_budget
|
||||
# +1 here to include the last generated token.
|
||||
num_new_tokens = min(num_new_tokens,
|
||||
num_scheduled_spec_tokens + 1)
|
||||
spec_token_budget -= num_scheduled_spec_tokens
|
||||
|
||||
# Make sure the input position does not exceed the max model len.
|
||||
# This is necessary when using spec decoding.
|
||||
num_new_tokens = min(
|
||||
|
Reference in New Issue
Block a user